/**
 *  @brief Device side implementation of the cooperative groups feature.
 */

#ifndef HIP_INCLUDE_HIP_AMD_DETAIL_HIP_COOPERATIVE_GROUPS_H
#define HIP_INCLUDE_HIP_AMD_DETAIL_HIP_COOPERATIVE_GROUPS_H

#if !defined(__HIPCC_RTC__)
#include <hip/amd_detail/hip_cooperative_groups_helper.h>
#endif

namespace cooperative_groups {
/** @brief The base type of all cooperative group types. */
class thread_group {
 protected:
  uint32_t _type;  // thread_group type
  uint32_t _size;  // total number of threads in the thread_group
  uint64_t _mask;  // lane mask for coalesced and tiled-partitioned group types

  explicit __CG_QUALIFIER__ thread_group(internal::group_type type,
                                         uint32_t size = static_cast<uint32_t>(0),
                                         uint64_t mask = static_cast<uint64_t>(0)) {
    _type = type;
    _size = size;
    _mask = mask;
  }
  struct _tiled_info {
    bool is_tiled;
    unsigned int size;
    unsigned int meta_group_rank;
    unsigned int meta_group_size;
  };
  struct _coalesced_info {
    lane_mask member_mask;
    unsigned int size;
    struct _tiled_info tiled_info;
  } coalesced_info;
  friend __CG_QUALIFIER__ thread_group tiled_partition(const thread_group& parent,
                                                       unsigned int tile_size);
  friend class thread_block;
 public:
  // Total number of threads in the group
  __CG_QUALIFIER__ uint32_t size() const { return _size; }
  // Type of the group
  __CG_QUALIFIER__ unsigned int cg_type() const { return _type; }
  // Rank of the calling thread within [0, size())
  __CG_QUALIFIER__ uint32_t thread_rank() const;
  // Whether the group has not violated any API constraints
  __CG_QUALIFIER__ bool is_valid() const;
  // Synchronize the threads in the group
  __CG_QUALIFIER__ void sync() const;
};
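// Usage sketch (illustrative, not part of this header): device code can accept
// any concrete group through the common thread_group interface and rely on the
// type-dispatching thread_rank()/sync() implementations defined further below.
//
//   __device__ unsigned int rank_after_sync(const cooperative_groups::thread_group& g) {
//     g.sync();                // dispatches on the concrete group type
//     return g.thread_rank();  // rank of the calling thread within g
//   }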
/** @brief The multi-grid cooperative group type: all threads of all grids in a
 *  cooperative multi-device launch.
 */
class multi_grid_group : public thread_group {
  // Only this friend function is allowed to construct an object of this class
  // and access its resources
  friend __CG_QUALIFIER__ multi_grid_group this_multi_grid();

 protected:
  // Construct a multi-grid thread group (through the API this_multi_grid())
  explicit __CG_QUALIFIER__ multi_grid_group(uint32_t size)
      : thread_group(internal::cg_multi_grid, size) {}

 public:
  // Number of grids participating in this multi-grid group
  __CG_QUALIFIER__ uint32_t num_grids() { return internal::multi_grid::num_grids(); }
  // Rank of this grid within [0, num_grids())
  __CG_QUALIFIER__ uint32_t grid_rank() { return internal::multi_grid::grid_rank(); }
  __CG_QUALIFIER__ uint32_t thread_rank() const { return internal::multi_grid::thread_rank(); }
  __CG_QUALIFIER__ bool is_valid() const { return internal::multi_grid::is_valid(); }
  __CG_QUALIFIER__ void sync() const { internal::multi_grid::sync(); }
};
__CG_QUALIFIER__ multi_grid_group this_multi_grid() {
  return multi_grid_group(internal::multi_grid::size());
}
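// Usage sketch (illustrative; the kernel name is an assumption): a kernel
// launched via hipLaunchCooperativeKernelMultiDevice can synchronize all
// threads of all participating grids.
//
//   __global__ void multi_grid_kernel() {
//     cooperative_groups::multi_grid_group mg = cooperative_groups::this_multi_grid();
//     // ... per-grid phase ...
//     mg.sync();  // waits for every thread of every participating grid
//   }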
/** @brief The grid cooperative group type: all threads of a cooperatively
 *  launched kernel.
 */
class grid_group : public thread_group {
  // Only this friend function is allowed to construct an object of this class
  friend __CG_QUALIFIER__ grid_group this_grid();

 protected:
  // Construct a grid thread group (through the API this_grid())
  explicit __CG_QUALIFIER__ grid_group(uint32_t size) : thread_group(internal::cg_grid, size) {}

 public:
  __CG_QUALIFIER__ uint32_t thread_rank() const { return internal::grid::thread_rank(); }
  __CG_QUALIFIER__ bool is_valid() const { return internal::grid::is_valid(); }
  __CG_QUALIFIER__ void sync() const { internal::grid::sync(); }
};
__CG_QUALIFIER__ grid_group this_grid() { return grid_group(internal::grid::size()); }
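// Usage sketch (illustrative; `partial_sums` is an assumed device buffer):
// grid-wide sync requires a cooperative launch (hipLaunchCooperativeKernel) so
// that all blocks are resident on the device.
//
//   __global__ void two_phase_kernel(float* partial_sums) {
//     cooperative_groups::grid_group grid = cooperative_groups::this_grid();
//     // phase 1: each block writes its partial result to partial_sums ...
//     grid.sync();  // all blocks observe the phase-1 results
//     // phase 2: consume partial_sums ...
//   }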
/** @brief The workgroup (thread block) cooperative group type. */
class thread_block : public thread_group {
  // Only these friend functions are allowed to construct an object of this
  // class and access its resources
  friend __CG_QUALIFIER__ thread_block this_thread_block();
  friend __CG_QUALIFIER__ thread_group tiled_partition(const thread_group& parent,
                                                       unsigned int tile_size);
  friend __CG_QUALIFIER__ thread_group tiled_partition(const thread_block& parent,
                                                       unsigned int tile_size);

 protected:
  // Construct a workgroup thread group (through the API this_thread_block())
  explicit __CG_QUALIFIER__ thread_block(uint32_t size)
      : thread_group(internal::cg_workgroup, size) {}
  __CG_QUALIFIER__ thread_group new_tiled_group(unsigned int tile_size) const {
    const bool pow2 = ((tile_size & (tile_size - 1)) == 0);
    // Tile size must be a nonzero power of 2 no larger than the wavefront size
    if (!tile_size || (tile_size > __AMDGCN_WAVEFRONT_SIZE) || !pow2) {
      __hip_assert(false && "invalid tile size");
    }

    auto block_size = size();
    auto rank = thread_rank();
    auto partitions = (block_size + tile_size - 1) / tile_size;
    auto tail = (partitions * tile_size) - block_size;
    auto partition_size = tile_size - tail * (rank >= (partitions - 1) * tile_size);
    thread_group tiledGroup = thread_group(internal::cg_tiled_group, partition_size);

    tiledGroup.coalesced_info.tiled_info.size = tile_size;
    tiledGroup.coalesced_info.tiled_info.is_tiled = true;
    tiledGroup.coalesced_info.tiled_info.meta_group_rank = rank / tile_size;
    tiledGroup.coalesced_info.tiled_info.meta_group_size = partitions;
    return tiledGroup;
  }

 public:
  // 3-dimensional index of the block within the launched grid
  __CG_STATIC_QUALIFIER__ dim3 group_index() { return internal::workgroup::group_index(); }
  // 3-dimensional index of the thread within the launched block
  __CG_STATIC_QUALIFIER__ dim3 thread_index() { return internal::workgroup::thread_index(); }
  __CG_STATIC_QUALIFIER__ uint32_t thread_rank() { return internal::workgroup::thread_rank(); }
  __CG_STATIC_QUALIFIER__ uint32_t size() { return internal::workgroup::size(); }
  __CG_STATIC_QUALIFIER__ bool is_valid() { return internal::workgroup::is_valid(); }
  __CG_STATIC_QUALIFIER__ void sync() { internal::workgroup::sync(); }
  __CG_QUALIFIER__ dim3 group_dim() { return internal::workgroup::block_dim(); }
};
__CG_QUALIFIER__ thread_block this_thread_block() {
  return thread_block(internal::workgroup::size());
}
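// Usage sketch (illustrative only): this_thread_block() is the
// cooperative-groups counterpart of raw threadIdx/__syncthreads() usage.
//
//   __global__ void block_kernel() {
//     cooperative_groups::thread_block tb = cooperative_groups::this_thread_block();
//     unsigned int r = tb.thread_rank();  // flattened rank within the block
//     tb.sync();                          // equivalent to __syncthreads()
//   }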
class tiled_group : public thread_group {
 private:
  friend __CG_QUALIFIER__ thread_group tiled_partition(const thread_group& parent,
                                                       unsigned int tile_size);
  friend __CG_QUALIFIER__ tiled_group tiled_partition(const tiled_group& parent,
                                                      unsigned int tile_size);
  __CG_QUALIFIER__ tiled_group new_tiled_group(unsigned int tile_size) const {
    const bool pow2 = ((tile_size & (tile_size - 1)) == 0);

    if (!tile_size || (tile_size > __AMDGCN_WAVEFRONT_SIZE) || !pow2) {
      __hip_assert(false && "invalid tile size");
    }

    // A tile at least as large as this group is simply this group
    if (size() <= tile_size) {
      return *this;
    }

    tiled_group tiledGroup = tiled_group(tile_size);
    tiledGroup.coalesced_info.tiled_info.is_tiled = true;
    return tiledGroup;
  }

 protected:
  // Construct a tiled thread group (through the API tiled_partition())
  explicit __CG_QUALIFIER__ tiled_group(unsigned int tileSize)
      : thread_group(internal::cg_tiled_group, tileSize) {
    coalesced_info.tiled_info.size = tileSize;
    coalesced_info.tiled_info.is_tiled = true;
  }

 public:
  __CG_QUALIFIER__ unsigned int size() const { return (coalesced_info.tiled_info.size); }

  __CG_QUALIFIER__ unsigned int thread_rank() const {
    return (internal::workgroup::thread_rank() & (coalesced_info.tiled_info.size - 1));
  }

  __CG_QUALIFIER__ void sync() const { internal::tiled_group::sync(); }
};
/** @brief The coalesced_group cooperative group type: the set of lanes that
 *  are currently active (converged).
 */
class coalesced_group : public thread_group {
 private:
  friend __CG_QUALIFIER__ coalesced_group coalesced_threads();
  friend __CG_QUALIFIER__ thread_group tiled_partition(const thread_group& parent,
                                                       unsigned int tile_size);
  friend __CG_QUALIFIER__ coalesced_group tiled_partition(const coalesced_group& parent,
                                                          unsigned int tile_size);
  __CG_QUALIFIER__ coalesced_group new_tiled_group(unsigned int tile_size) const {
    const bool pow2 = ((tile_size & (tile_size - 1)) == 0);

    if (!tile_size || (tile_size > size()) || !pow2) {
      return coalesced_group(0);
    }

    // If the group is already tiled, its member lanes are contiguous, so the
    // child mask can be computed directly from the calling lane's position.
    if (coalesced_info.tiled_info.is_tiled) {
      unsigned int base_offset = (thread_rank() & (~(tile_size - 1)));
      unsigned int masklength = min(static_cast<unsigned int>(size()) - base_offset, tile_size);
      lane_mask member_mask = static_cast<lane_mask>(-1) >> (__AMDGCN_WAVEFRONT_SIZE - masklength);

      member_mask <<= (__lane_id() & ~(tile_size - 1));
      coalesced_group coalesced_tile = coalesced_group(member_mask);
      coalesced_tile.coalesced_info.tiled_info.is_tiled = true;
      coalesced_tile.coalesced_info.tiled_info.meta_group_rank = thread_rank() / tile_size;
      coalesced_tile.coalesced_info.tiled_info.meta_group_size = size() / tile_size;
      return coalesced_tile;
    } else {
      // Otherwise walk the member mask, skipping lanes that belong to earlier
      // tiles and collecting the next tile_size active lanes.
      lane_mask member_mask = 0;
      unsigned int tile_rank = 0;
      int lanes_to_skip = ((thread_rank()) / tile_size) * tile_size;

      for (unsigned int i = 0; i < __AMDGCN_WAVEFRONT_SIZE; i++) {
        // Use a lane_mask-wide shift so the mask cannot overflow on a 64-wide wavefront
        lane_mask active = coalesced_info.member_mask & (static_cast<lane_mask>(1) << i);
        // Consider only active lanes
        if (active) {
          if (lanes_to_skip <= 0 && tile_rank < tile_size) {
            member_mask |= active;
            tile_rank++;
          }
          lanes_to_skip--;
        }
      }
      coalesced_group coalesced_tile = coalesced_group(member_mask);
      coalesced_tile.coalesced_info.tiled_info.meta_group_rank = thread_rank() / tile_size;
      coalesced_tile.coalesced_info.tiled_info.meta_group_size =
          (size() + tile_size - 1) / tile_size;
      return coalesced_tile;
    }
    return coalesced_group(0);
  }

 protected:
  // Construct a coalesced thread group (through the API coalesced_threads())
  explicit __CG_QUALIFIER__ coalesced_group(lane_mask member_mask)
      : thread_group(internal::cg_coalesced_group) {
    coalesced_info.member_mask = member_mask;
    coalesced_info.size = __popcll(coalesced_info.member_mask);
    coalesced_info.tiled_info.is_tiled = false;
    coalesced_info.tiled_info.meta_group_rank = 0;
    coalesced_info.tiled_info.meta_group_size = 1;
  }

 public:
  __CG_QUALIFIER__ unsigned int size() const { return coalesced_info.size; }

  __CG_QUALIFIER__ unsigned int thread_rank() const {
    return internal::coalesced_group::masked_bit_count(coalesced_info.member_mask);
  }

  __CG_QUALIFIER__ void sync() const { internal::coalesced_group::sync(); }

  __CG_QUALIFIER__ unsigned int meta_group_rank() const {
    return coalesced_info.tiled_info.meta_group_rank;
  }

  __CG_QUALIFIER__ unsigned int meta_group_size() const {
    return coalesced_info.tiled_info.meta_group_size;
  }
  template <class T> __CG_QUALIFIER__ T shfl(T var, int srcRank) const {
    static_assert(is_valid_type<T>::value, "Neither an integer nor a float type.");

    srcRank = srcRank % static_cast<int>(size());

    int lane = (size() == __AMDGCN_WAVEFRONT_SIZE) ? srcRank
        : (__AMDGCN_WAVEFRONT_SIZE == 64)
            ? __fns64(coalesced_info.member_mask, 0, (srcRank + 1))
            : __fns32(coalesced_info.member_mask, 0, (srcRank + 1));

    return __shfl(var, lane, __AMDGCN_WAVEFRONT_SIZE);
  }
  template <class T> __CG_QUALIFIER__ T shfl_down(T var, unsigned int lane_delta) const {
    static_assert(is_valid_type<T>::value, "Neither an integer nor a float type.");

    // A full wavefront can use the hardware shuffle directly
    if (size() == __AMDGCN_WAVEFRONT_SIZE) {
      return __shfl_down(var, lane_delta, __AMDGCN_WAVEFRONT_SIZE);
    }

    int lane;
    if (__AMDGCN_WAVEFRONT_SIZE == 64) {
      lane = __fns64(coalesced_info.member_mask, __lane_id(), lane_delta + 1);
    } else {
      lane = __fns32(coalesced_info.member_mask, __lane_id(), lane_delta + 1);
    }

    // An out-of-range delta reads the caller's own value
    if (lane == -1) {
      lane = __lane_id();
    }

    return __shfl(var, lane, __AMDGCN_WAVEFRONT_SIZE);
  }
  template <class T> __CG_QUALIFIER__ T shfl_up(T var, unsigned int lane_delta) const {
    static_assert(is_valid_type<T>::value, "Neither an integer nor a float type.");

    if (size() == __AMDGCN_WAVEFRONT_SIZE) {
      return __shfl_up(var, lane_delta, __AMDGCN_WAVEFRONT_SIZE);
    }

    int lane;
    if (__AMDGCN_WAVEFRONT_SIZE == 64) {
      lane = __fns64(coalesced_info.member_mask, __lane_id(), -(lane_delta + 1));
    } else if (__AMDGCN_WAVEFRONT_SIZE == 32) {
      lane = __fns32(coalesced_info.member_mask, __lane_id(), -(lane_delta + 1));
    }

    // An out-of-range delta reads the caller's own value
    if (lane == -1) {
      lane = __lane_id();
    }

    return __shfl(var, lane, __AMDGCN_WAVEFRONT_SIZE);
  }
};
__CG_QUALIFIER__ coalesced_group coalesced_threads() {
  return cooperative_groups::coalesced_group(__builtin_amdgcn_read_exec());
}
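// Usage sketch (illustrative; buffer names are assumptions): inside divergent
// control flow, coalesced_threads() captures exactly the active lanes, and
// shfl() can broadcast from the group leader (rank 0).
//
//   __global__ void divergent_kernel(const int* in, int* out) {
//     if (in[threadIdx.x] > 0) {  // only some lanes take this branch
//       cooperative_groups::coalesced_group active = cooperative_groups::coalesced_threads();
//       int leader_value = active.shfl(in[threadIdx.x], 0);  // value held by rank 0
//       out[threadIdx.x] = leader_value;
//     }
//   }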
// Implementation of the publicly exposed thread_group API: dispatch on the
// concrete group type.
__CG_QUALIFIER__ uint32_t thread_group::thread_rank() const {
  switch (this->_type) {
    case internal::cg_multi_grid: {
      return (static_cast<const multi_grid_group*>(this)->thread_rank());
    }
    case internal::cg_grid: {
      return (static_cast<const grid_group*>(this)->thread_rank());
    }
    case internal::cg_workgroup: {
      return (static_cast<const thread_block*>(this)->thread_rank());
    }
    case internal::cg_tiled_group: {
      return (static_cast<const tiled_group*>(this)->thread_rank());
    }
    case internal::cg_coalesced_group: {
      return (static_cast<const coalesced_group*>(this)->thread_rank());
    }
    default: {
      __hip_assert(false && "invalid cooperative group type");
      return -1;
    }
  }
}
__CG_QUALIFIER__ bool thread_group::is_valid() const {
  switch (this->_type) {
    case internal::cg_multi_grid: {
      return (static_cast<const multi_grid_group*>(this)->is_valid());
    }
    case internal::cg_grid: {
      return (static_cast<const grid_group*>(this)->is_valid());
    }
    case internal::cg_workgroup: {
      return (static_cast<const thread_block*>(this)->is_valid());
    }
    case internal::cg_tiled_group: {
      return (static_cast<const tiled_group*>(this)->is_valid());
    }
    case internal::cg_coalesced_group: {
      return (static_cast<const coalesced_group*>(this)->is_valid());
    }
    default: {
      __hip_assert(false && "invalid cooperative group type");
      return false;
    }
  }
}
__CG_QUALIFIER__ void thread_group::sync() const {
  switch (this->_type) {
    case internal::cg_multi_grid: {
      static_cast<const multi_grid_group*>(this)->sync();
      break;
    }
    case internal::cg_grid: {
      static_cast<const grid_group*>(this)->sync();
      break;
    }
    case internal::cg_workgroup: {
      static_cast<const thread_block*>(this)->sync();
      break;
    }
    case internal::cg_tiled_group: {
      static_cast<const tiled_group*>(this)->sync();
      break;
    }
    case internal::cg_coalesced_group: {
      static_cast<const coalesced_group*>(this)->sync();
      break;
    }
    default: {
      __hip_assert(false && "invalid cooperative group type");
    }
  }
}
// Generic free functions that forward to the matching member of any group type

// Returns the size of the group
template <class CGTy> __CG_QUALIFIER__ uint32_t group_size(CGTy const& g) { return g.size(); }

// Returns the rank of the calling thread within [0, group_size())
template <class CGTy> __CG_QUALIFIER__ uint32_t thread_rank(CGTy const& g) {
  return g.thread_rank();
}

// Returns true if the group has not violated any API constraints
template <class CGTy> __CG_QUALIFIER__ bool is_valid(CGTy const& g) { return g.is_valid(); }

// Synchronizes the threads in the group
template <class CGTy> __CG_QUALIFIER__ void sync(CGTy const& g) { g.sync(); }
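// Usage sketch (illustrative only): the free functions above make it possible
// to write device code that is generic over the group type.
//
//   template <class Group>
//   __device__ void barrier_and_report(const Group& g, unsigned int* out) {
//     cooperative_groups::sync(g);
//     if (cooperative_groups::thread_rank(g) == 0) {
//       *out = cooperative_groups::group_size(g);
//     }
//   }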
template <unsigned int tileSize> class tile_base {
 protected:
  _CG_STATIC_CONST_DECL_ unsigned int numThreads = tileSize;

 public:
  // Rank of the thread within this tile
  __CG_STATIC_QUALIFIER__ unsigned int thread_rank() {
    return (internal::workgroup::thread_rank() & (numThreads - 1));
  }

  // Number of threads within this tile
  __CG_STATIC_QUALIFIER__ unsigned int size() { return numThreads; }
};
template <unsigned int size> class thread_block_tile_base : public tile_base<size> {
  static_assert(is_valid_tile_size<size>::value,
                "Tile size is either not a power of 2 or greater than the wavefront size");
  using tile_base<size>::numThreads;

 public:
  __CG_STATIC_QUALIFIER__ void sync() { internal::tiled_group::sync(); }
  template <class T> __CG_QUALIFIER__ T shfl(T var, int srcRank) const {
    static_assert(is_valid_type<T>::value, "Neither an integer nor a float type.");
    return (__shfl(var, srcRank, numThreads));
  }

  template <class T> __CG_QUALIFIER__ T shfl_down(T var, unsigned int lane_delta) const {
    static_assert(is_valid_type<T>::value, "Neither an integer nor a float type.");
    return (__shfl_down(var, lane_delta, numThreads));
  }

  template <class T> __CG_QUALIFIER__ T shfl_up(T var, unsigned int lane_delta) const {
    static_assert(is_valid_type<T>::value, "Neither an integer nor a float type.");
    return (__shfl_up(var, lane_delta, numThreads));
  }

  template <class T> __CG_QUALIFIER__ T shfl_xor(T var, unsigned int laneMask) const {
    static_assert(is_valid_type<T>::value, "Neither an integer nor a float type.");
    return (__shfl_xor(var, laneMask, numThreads));
  }
};
template <unsigned int tileSize, typename ParentCGTy>
class parent_group_info {
 public:
  // Returns the linear rank of the group within the set of tiles partitioned
  // out of the parent group (bounded by meta_group_size)
  __CG_STATIC_QUALIFIER__ unsigned int meta_group_rank() {
    return ParentCGTy::thread_rank() / tileSize;
  }

  // Returns the number of groups created when the parent group was partitioned
  __CG_STATIC_QUALIFIER__ unsigned int meta_group_size() {
    return (ParentCGTy::size() + tileSize - 1) / tileSize;
  }
};
template <unsigned int tileSize, class ParentCGTy>
class thread_block_tile_type : public thread_block_tile_base<tileSize>,
                               public tiled_group,
                               public parent_group_info<tileSize, ParentCGTy> {
  _CG_STATIC_CONST_DECL_ unsigned int numThreads = tileSize;
  typedef thread_block_tile_base<numThreads> tbtBase;

 protected:
  __CG_QUALIFIER__ thread_block_tile_type() : tiled_group(numThreads) {
    coalesced_info.tiled_info.size = numThreads;
    coalesced_info.tiled_info.is_tiled = true;
  }

 public:
  using tbtBase::size;
  using tbtBase::sync;
  using tbtBase::thread_rank;
};
// Partial specialization for the case where the parent group type is erased
template <unsigned int tileSize>
class thread_block_tile_type<tileSize, void> : public thread_block_tile_base<tileSize>,
                                               public tiled_group {
  _CG_STATIC_CONST_DECL_ unsigned int numThreads = tileSize;

  typedef thread_block_tile_base<numThreads> tbtBase;

 protected:
  __CG_QUALIFIER__ thread_block_tile_type(unsigned int meta_group_rank,
                                          unsigned int meta_group_size)
      : tiled_group(numThreads) {
    coalesced_info.tiled_info.size = numThreads;
    coalesced_info.tiled_info.is_tiled = true;
    coalesced_info.tiled_info.meta_group_rank = meta_group_rank;
    coalesced_info.tiled_info.meta_group_size = meta_group_size;
  }
 public:
  using tbtBase::size;
  using tbtBase::sync;
  using tbtBase::thread_rank;

  __CG_QUALIFIER__ unsigned int meta_group_rank() const {
    return coalesced_info.tiled_info.meta_group_rank;
  }

  __CG_QUALIFIER__ unsigned int meta_group_size() const {
    return coalesced_info.tiled_info.meta_group_size;
  }
};
/** @brief Create a partition of tiled groups out of the parent group. Dynamic
 *  (runtime) tile sizes must be a power of 2 no larger than the wavefront size.
 */
__CG_QUALIFIER__ thread_group tiled_partition(const thread_group& parent,
                                              unsigned int tile_size) {
  if (parent.cg_type() == internal::cg_tiled_group) {
    const tiled_group* cg = static_cast<const tiled_group*>(&parent);
    return cg->new_tiled_group(tile_size);
  } else if (parent.cg_type() == internal::cg_coalesced_group) {
    const coalesced_group* cg = static_cast<const coalesced_group*>(&parent);
    return cg->new_tiled_group(tile_size);
  } else {
    const thread_block* tb = static_cast<const thread_block*>(&parent);
    return tb->new_tiled_group(tile_size);
  }
}
// Overload for thread_block
__CG_QUALIFIER__ thread_group tiled_partition(const thread_block& parent,
                                              unsigned int tile_size) {
  return (parent.new_tiled_group(tile_size));
}

// Overload for tiled_group
__CG_QUALIFIER__ tiled_group tiled_partition(const tiled_group& parent,
                                             unsigned int tile_size) {
  return (parent.new_tiled_group(tile_size));
}

// Overload for coalesced_group
__CG_QUALIFIER__ coalesced_group tiled_partition(const coalesced_group& parent,
                                                 unsigned int tile_size) {
  return (parent.new_tiled_group(tile_size));
}
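// Usage sketch (illustrative only): runtime-sized partition of the current
// block; the tile size must be a power of 2 no larger than the wavefront size.
//
//   __global__ void dynamic_tile_kernel() {
//     cooperative_groups::thread_block tb = cooperative_groups::this_thread_block();
//     cooperative_groups::thread_group tile16 = cooperative_groups::tiled_partition(tb, 16);
//     tile16.sync();  // synchronizes only the 16-wide tile containing this thread
//   }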
template <unsigned int size, class ParentCGTy> class thread_block_tile;

namespace impl {
template <unsigned int size, class ParentCGTy> class thread_block_tile_internal;
template <unsigned int size, class ParentCGTy>
class thread_block_tile_internal : public thread_block_tile_type<size, ParentCGTy> {
 protected:
  template <unsigned int tbtSize, class tbtParentT>
  __CG_QUALIFIER__ thread_block_tile_internal(
      const thread_block_tile_internal<tbtSize, tbtParentT>& g)
      : thread_block_tile_type<size, ParentCGTy>(g.meta_group_rank(), g.meta_group_size()) {}

  __CG_QUALIFIER__ thread_block_tile_internal(const thread_block& g)
      : thread_block_tile_type<size, ParentCGTy>() {}
};
}  // namespace impl
template <unsigned int size, class ParentCGTy>
class thread_block_tile : public impl::thread_block_tile_internal<size, ParentCGTy> {
 protected:
  __CG_QUALIFIER__ thread_block_tile(const ParentCGTy& g)
      : impl::thread_block_tile_internal<size, ParentCGTy>(g) {}

 public:
  __CG_QUALIFIER__ operator thread_block_tile<size, void>() const {
    return thread_block_tile<size, void>(*this);
  }
};
template <unsigned int size>
class thread_block_tile<size, void> : public impl::thread_block_tile_internal<size, void> {
  template <unsigned int, class ParentCGTy> friend class thread_block_tile;

 protected:
  template <class ParentCGTy>
  __CG_QUALIFIER__ thread_block_tile(const thread_block_tile<size, ParentCGTy>& g)
      : impl::thread_block_tile_internal<size, void>(g) {}
};
template <unsigned int size, class ParentCGTy = void> class thread_block_tile;
namespace impl {
template <unsigned int size, class ParentCGTy> struct tiled_partition_internal;
template <unsigned int size>
struct tiled_partition_internal<size, thread_block>
    : public thread_block_tile<size, thread_block> {
  __CG_QUALIFIER__ tiled_partition_internal(const thread_block& g)
      : thread_block_tile<size, thread_block>(g) {}
};
}  // namespace impl
/** @brief Create a partition with a compile-time tile size out of the parent
 *  group.
 */
template <unsigned int size, class ParentCGTy>
__CG_QUALIFIER__ thread_block_tile<size, ParentCGTy> tiled_partition(const ParentCGTy& g) {
  static_assert(is_valid_tile_size<size>::value,
                "Tiled partition with size > wavefront size is currently not supported");
  return impl::tiled_partition_internal<size, ParentCGTy>(g);
}
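// Usage sketch (illustrative; buffer names are assumptions): a compile-time
// tile size yields a thread_block_tile with warp-level shuffle members, e.g.
// an 8-lane tile reduction via shfl_down.
//
//   __global__ void tile_reduce_kernel(const float* in, float* out) {
//     cooperative_groups::thread_block tb = cooperative_groups::this_thread_block();
//     cooperative_groups::thread_block_tile<8> tile = cooperative_groups::tiled_partition<8>(tb);
//     float v = in[tb.thread_rank()];
//     for (unsigned int d = tile.size() / 2; d > 0; d /= 2) {
//       v += tile.shfl_down(v, d);  // after the loop, rank 0 holds the tile's sum
//     }
//     if (tile.thread_rank() == 0) out[tile.meta_group_rank()] = v;
//   }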
}  // namespace cooperative_groups

#endif  // HIP_INCLUDE_HIP_AMD_DETAIL_HIP_COOPERATIVE_GROUPS_H