26template <
typename ThreadGroup,
31 typename ElementwiseOperation,
33 typename SliceLengths,
34 typename ThreadClusterLengths,
35 typename ThreadClusterArrangeOrder,
36 typename DimAccessOrder,
39 typename ThreadTransferSrcResetCoordinateAfterRunFlags,
40 typename ThreadTransferDstResetCoordinateAfterRunFlags>
54 const SrcDescs& src_descs,
56 const DstDescs& dst_descs,
58 const ElementwiseOperation& element_op)
59 : threadwise_transfer_(src_descs,
65 static_assert(
nSrc == SrcDatas::Size() &&
nSrc == SrcDescs::Size() &&
66 nSrc == ThreadTransferSrcResetCoordinateAfterRunFlags::Size() &&
67 nDst == DstDatas::Size() &&
nDst == DstDescs::Size() &&
68 nDst == ThreadTransferDstResetCoordinateAfterRunFlags::Size(),
71 static_for<0, nSrc, 1>{}([&](
auto i) {
77 static_for<0, nDst, 1>{}([&](
auto i) {
83 static_assert(
nDim == ThreadClusterLengths::Size() &&
84 nDim == ThreadClusterArrangeOrder::Size() &&
85 nDim == DimAccessOrder::Size(),
86 "wrong! nDim not consistent");
90 "wrong! threads should be mapped to cover entire slicing window");
92 static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
93 "wrong! ThreadGroup::GetNumOfThread() too small");
95 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
96 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
98 const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
104 [&](
auto i) {
return src_block_slice_origins[i] + thread_data_idx_begin; },
108 [&](
auto i) {
return dst_block_slice_origins[i] + thread_data_idx_begin; },
111 threadwise_transfer_.SetSrcSliceOrigins(src_descs, src_thread_slice_origins);
112 threadwise_transfer_.SetDstSliceOrigins(dst_descs, dst_thread_slice_origins);
116 template <
typename SrcBuffers,
typename DstBuffers>
117 __device__
void Run(
const SrcDescs& src_descs,
118 const SrcBuffers& src_bufs,
119 const DstDescs& dst_descs,
122 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
123 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
125 threadwise_transfer_.Run(src_descs, src_bufs, dst_descs, dst_bufs);
129 template <index_t ISrc>
133 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
134 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
136 threadwise_transfer_.MoveSrcSliceWindow(src_descs, iSrc, step);
140 template <index_t IDst>
144 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
145 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
147 threadwise_transfer_.MoveDstSliceWindow(dst_descs, iDst, step);
152 static constexpr auto thread_cluster_desc_ =
155 using ThreadwiseTransfer =
156 ThreadwiseTensorSliceTransfer_v7<SrcDatas,
160 ElementwiseOperation,
166 ThreadTransferSrcResetCoordinateAfterRunFlags,
167 ThreadTransferDstResetCoordinateAfterRunFlags>;
169 ThreadwiseTransfer threadwise_transfer_;
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
int32_t index_t
Definition ck.hpp:299
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition tensor_description/cluster_descriptor.hpp:13
integral_constant< index_t, N > Number
Definition number.hpp:12
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
Array< index_t, N > MultiIndex
Definition array_multi_index.hpp:12
static constexpr index_t nDst
Definition thread_group_tensor_slice_transfer_v7.hpp:47
static constexpr index_t nSrc
Definition thread_group_tensor_slice_transfer_v7.hpp:46
__device__ constexpr ThreadGroupTensorSliceTransfer_v7(const SrcDescs &src_descs, const StaticallyIndexedArray< Index, nSrc > &src_block_slice_origins, const DstDescs &dst_descs, const StaticallyIndexedArray< Index, nDst > &dst_block_slice_origins, const ElementwiseOperation &element_op)
Definition thread_group_tensor_slice_transfer_v7.hpp:53
MultiIndex< nDim > Index
Definition thread_group_tensor_slice_transfer_v7.hpp:49
static constexpr auto thread_slice_lengths
Definition thread_group_tensor_slice_transfer_v7.hpp:51
static constexpr index_t nDim
Definition thread_group_tensor_slice_transfer_v7.hpp:43
__device__ void MoveSrcSliceWindow(const SrcDescs &src_descs, Number< ISrc > iSrc, const Index &step)
Definition thread_group_tensor_slice_transfer_v7.hpp:131
__device__ void MoveDstSliceWindow(const DstDescs &dst_descs, Number< IDst > iDst, const Index &step)
Definition thread_group_tensor_slice_transfer_v7.hpp:142
__device__ void Run(const SrcDescs &src_descs, const SrcBuffers &src_bufs, const DstDescs &dst_descs, DstBuffers dst_bufs)
Definition thread_group_tensor_slice_transfer_v7.hpp:117