@@ -76,6 +76,34 @@ __device__ bool init_args(
   return all_aligned;
 }
 
+template <
+    int depth,
+    typename param_type,
+    typename grad_type,
+    typename exp_avg_type,
+    typename exp_avg_sq_type>
+__device__ bool init_args_mixed_prec(
+    param_type** param_args,
+    grad_type** grad_args,
+    exp_avg_type** exp_avg_args,
+    exp_avg_sq_type** exp_avg_sq_args,
+    FusedOptimizerTensorListMetadata<depth>& tl,
+    const int64_t chunk_idx,
+    const int64_t chunk_size,
+    const int64_t tensor_loc) {
+  *param_args =
+      (param_type*)tl.addresses[0][tensor_loc] + chunk_idx * chunk_size;
+  *grad_args = (grad_type*)tl.addresses[1][tensor_loc] + chunk_idx * chunk_size;
+  *exp_avg_args =
+      (exp_avg_type*)tl.addresses[2][tensor_loc] + chunk_idx * chunk_size;
+  *exp_avg_sq_args =
+      (exp_avg_sq_type*)tl.addresses[3][tensor_loc] + chunk_idx * chunk_size;
+
+  bool all_aligned = is_aligned(*param_args) && is_aligned(*grad_args) &&
+      is_aligned(*exp_avg_args) && is_aligned(*exp_avg_sq_args);
+  return all_aligned;
+}
+
 template <int depth, typename T>
 __device__ void load_args(
     T r_args[][kILP],
@@ -95,6 +123,44 @@ __device__ void load_args(
   }
 }
 
+template <
+    typename T,
+    typename param_type,
+    typename grad_type,
+    typename exp_avg_type,
+    typename exp_avg_sq_type>
+__device__ void load_args(
+    T r_args[][kILP],
+    const param_type* param_args,
+    const grad_type* grad_args,
+    const exp_avg_type* exp_avg_args,
+    const exp_avg_sq_type* exp_avg_sq_args,
+    const int64_t i_start,
+    const int64_t chunk_size,
+    const int64_t n) {
+#pragma unroll
+  for (int ii = 0; ii < kILP; ii++) {
+    // const auto i = i_start + threadIdx.x + ii * blockDim.x;
+    const auto i = i_start + threadIdx.x * kILP + ii;
+    r_args[0][ii] = 0;
+    if (i < n && i < chunk_size) {
+      r_args[0][ii] = static_cast<T>(param_args[i]);
+    }
+    r_args[1][ii] = 0;
+    if (i < n && i < chunk_size) {
+      r_args[1][ii] = static_cast<T>(grad_args[i]);
+    }
+    r_args[2][ii] = 0;
+    if (i < n && i < chunk_size) {
+      r_args[2][ii] = static_cast<T>(exp_avg_args[i]);
+    }
+    r_args[3][ii] = 0;
+    if (i < n && i < chunk_size) {
+      r_args[3][ii] = static_cast<T>(exp_avg_sq_args[i]);
+    }
+  }
+}
+
 template <typename T>
 __device__ void store_args(
     T* dst,
@@ -104,12 +170,29 @@ __device__ void store_args(
     const int64_t n) {
 #pragma unroll
   for (int ii = 0; ii < kILP; ii++) {
-    const int64_t i = i_start + threadIdx.x + ii * blockDim.x;
+    // const int64_t i = i_start + threadIdx.x + ii * blockDim.x;
+    const auto i = i_start + threadIdx.x * kILP + ii;
     if (i < n && i < chunk_size)
       dst[i] = src[ii];
   }
 }
 
+template <typename dT, typename sT>
+__device__ void store_args(
+    dT* dst,
+    sT* src,
+    const int64_t i_start,
+    const int64_t chunk_size,
+    const int64_t n) {
+#pragma unroll
+  for (int ii = 0; ii < kILP; ii++) {
+    // const int64_t i = i_start + threadIdx.x + ii * blockDim.x;
+    const auto i = i_start + threadIdx.x * kILP + ii;
+    if (i < n && i < chunk_size)
+      dst[i] = static_cast<dT>(src[ii]);
+  }
+}
+
 template <int res_arg_index, typename Op, typename T, typename opmath_t>
 __device__ __forceinline__ void binary_op_scalar(
     T r_args[][kILP],
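For orientation, the sketch below shows how the three mixed-precision helpers added in this diff are meant to compose inside a fused-optimizer chunk body: init_args_mixed_prec resolves the four per-tensor base pointers once per chunk, the new load_args overload up-casts kILP elements per thread into opmath-type registers, the update math runs in that register type, and the new store_args overload casts each buffer back to its own dtype on the way out. This is a simplified sketch, not code from the PR: the adamw_chunk_sketch name and signature, the fp32 opmath_t, the plain Adam step (no bias correction or weight decay), and passing n in as a parameter are all illustrative assumptions, and the helpers defined above are assumed to be in scope in the same header.

// Illustrative sketch only -- a simplified per-chunk body built on the
// helpers added in this diff. Name, signature, and the Adam math here are
// assumptions, not part of the change.
template <int depth, typename param_t, typename grad_t, typename state_t>
__device__ void adamw_chunk_sketch(
    FusedOptimizerTensorListMetadata<depth>& tl,
    const int64_t chunk_idx,
    const int64_t chunk_size,
    const int64_t tensor_loc,
    const int64_t n, // remaining numel for this tensor (assumed passed in)
    const float lr,
    const float beta1,
    const float beta2,
    const float eps) {
  using opmath_t = float; // update math is done in fp32 registers

  param_t* param = nullptr;
  grad_t* grad = nullptr;
  state_t* exp_avg = nullptr;
  state_t* exp_avg_sq = nullptr;
  // Resolve this chunk's base pointers, one dtype per tensor list.
  init_args_mixed_prec<depth>(
      &param, &grad, &exp_avg, &exp_avg_sq, tl, chunk_idx, chunk_size, tensor_loc);

  opmath_t r_args[4][kILP]; // rows: param, grad, exp_avg, exp_avg_sq
  for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
       i_start += blockDim.x * kILP) {
    // Up-cast kILP contiguous elements per thread into fp32 registers.
    load_args(r_args, param, grad, exp_avg, exp_avg_sq, i_start, chunk_size, n);

#pragma unroll
    for (int ii = 0; ii < kILP; ii++) {
      // Plain Adam step in fp32 (bias correction and weight decay omitted).
      r_args[2][ii] = beta1 * r_args[2][ii] + (1 - beta1) * r_args[1][ii];
      r_args[3][ii] =
          beta2 * r_args[3][ii] + (1 - beta2) * r_args[1][ii] * r_args[1][ii];
      r_args[0][ii] -= lr * r_args[2][ii] / (::sqrtf(r_args[3][ii]) + eps);
    }

    // Down-cast on the way out; each destination keeps its own dtype.
    store_args(param, r_args[0], i_start, chunk_size, n);
    store_args(exp_avg, r_args[2], i_start, chunk_size, n);
    store_args(exp_avg_sq, r_args[3], i_start, chunk_size, n);
  }
}

Note that both new overloads index with i_start + threadIdx.x * kILP + ii, so each thread owns a contiguous run of kILP elements rather than the strided threadIdx.x + ii * blockDim.x layout used by the existing helpers; the i_start stride of blockDim.x * kILP in the sketch assumes that layout.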