@@ -76,6 +76,34 @@ __device__ bool init_args(
   return all_aligned;
 }
 
+template <
+    int depth,
+    typename param_type,
+    typename grad_type,
+    typename exp_avg_type,
+    typename exp_avg_sq_type>
+__device__ bool init_args_mixed_prec(
+    param_type** param_args,
+    grad_type** grad_args,
+    exp_avg_type** exp_avg_args,
+    exp_avg_sq_type** exp_avg_sq_args,
+    FusedOptimizerTensorListMetadata<depth>& tl,
+    const int64_t chunk_idx,
+    const int64_t chunk_size,
+    const int64_t tensor_loc) {
+  *param_args =
+      (param_type*)tl.addresses[0][tensor_loc] + chunk_idx * chunk_size;
+  *grad_args = (grad_type*)tl.addresses[1][tensor_loc] + chunk_idx * chunk_size;
+  *exp_avg_args =
+      (exp_avg_type*)tl.addresses[2][tensor_loc] + chunk_idx * chunk_size;
+  *exp_avg_sq_args =
+      (exp_avg_sq_type*)tl.addresses[3][tensor_loc] + chunk_idx * chunk_size;
+
+  bool all_aligned = is_aligned(*param_args) && is_aligned(*grad_args) &&
+      is_aligned(*exp_avg_args) && is_aligned(*exp_avg_sq_args);
+  return all_aligned;
+}
+
 template <int depth, typename T>
 __device__ void load_args(
     T r_args[][kILP],
@@ -95,6 +123,43 @@ __device__ void load_args(
   }
 }
 
+template <
+    typename T,
+    typename param_type,
+    typename grad_type,
+    typename exp_avg_type,
+    typename exp_avg_sq_type>
+__device__ void load_args(
+    T r_args[][kILP],
+    const param_type* param_args,
+    const grad_type* grad_args,
+    const exp_avg_type* exp_avg_args,
+    const exp_avg_sq_type* exp_avg_sq_args,
+    const int64_t i_start,
+    const int64_t chunk_size,
+    const int64_t n) {
+#pragma unroll
+  for (int ii = 0; ii < kILP; ii++) {
+    const auto i = i_start + threadIdx.x + ii * blockDim.x;
+    r_args[0][ii] = 0;
+    if (i < n && i < chunk_size) {
+      r_args[0][ii] = static_cast<T>(param_args[i]);
+    }
+    r_args[1][ii] = 0;
+    if (i < n && i < chunk_size) {
+      r_args[1][ii] = static_cast<T>(grad_args[i]);
+    }
+    r_args[2][ii] = 0;
+    if (i < n && i < chunk_size) {
+      r_args[2][ii] = static_cast<T>(exp_avg_args[i]);
+    }
+    r_args[3][ii] = 0;
+    if (i < n && i < chunk_size) {
+      r_args[3][ii] = static_cast<T>(exp_avg_sq_args[i]);
+    }
+  }
+}
+
 template <typename T>
 __device__ void store_args(
     T* dst,
@@ -110,6 +175,21 @@ __device__ void store_args(
   }
 }
 
+template <typename dT, typename sT>
+__device__ void store_args(
+    dT* dst,
+    sT* src,
+    const int64_t i_start,
+    const int64_t chunk_size,
+    const int64_t n) {
+#pragma unroll
+  for (int ii = 0; ii < kILP; ii++) {
+    const int64_t i = i_start + threadIdx.x + ii * blockDim.x;
+    if (i < n && i < chunk_size)
+      dst[i] = static_cast<dT>(src[ii]);
+  }
+}
+
 template <int res_arg_index, typename Op, typename T, typename opmath_t>
 __device__ __forceinline__ void binary_op_scalar(
     T r_args[][kILP],
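
For context, the three helpers added above are meant to compose inside the usual multi_tensor_apply-style chunk loop: init_args_mixed_prec resolves one base pointer per state tensor (each with its own dtype), the mixed-precision load_args upcasts kILP elements per thread into opmath_t registers, and the dtype-converting store_args downcasts on write-back. The code below is a minimal sketch only, not part of this patch: the functor name FusedAdamMixedPrecMathFunctor is hypothetical, the optimizer math and hyper-parameters are elided, and the block_to_tensor / block_to_chunk / numel_for_tensor members are assumed to follow the existing metadata struct.

// Illustrative sketch only (not part of this patch).
template <
    typename opmath_t,
    typename param_type,
    typename grad_type,
    typename exp_avg_type,
    typename exp_avg_sq_type,
    int depth = 4>
struct FusedAdamMixedPrecMathFunctor {
  __device__ __forceinline__ void operator()(
      int64_t chunk_size,
      FusedOptimizerTensorListMetadata<depth>& tl,
      const double lr /* remaining hyper-parameters elided */) {
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
    const auto n = tl.numel_for_tensor[tensor_loc] - chunk_idx * chunk_size;

    param_type* param_ptr = nullptr;
    grad_type* grad_ptr = nullptr;
    exp_avg_type* exp_avg_ptr = nullptr;
    exp_avg_sq_type* exp_avg_sq_ptr = nullptr;

    // One base pointer per state tensor, each in its own dtype.
    init_args_mixed_prec<depth>(
        &param_ptr, &grad_ptr, &exp_avg_ptr, &exp_avg_sq_ptr,
        tl, chunk_idx, chunk_size, tensor_loc);

    opmath_t r_args[depth][kILP];
    for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
         i_start += blockDim.x * kILP) {
      // Upcast kILP elements per thread from each list into opmath_t registers.
      load_args(
          r_args, param_ptr, grad_ptr, exp_avg_ptr, exp_avg_sq_ptr,
          i_start, chunk_size, n);

      // ... per-element optimizer math on r_args (elided) ...

      // Write back with a per-destination downcast.
      store_args(param_ptr, r_args[0], i_start, chunk_size, n);
      store_args(exp_avg_ptr, r_args[2], i_start, chunk_size, n);
      store_args(exp_avg_sq_ptr, r_args[3], i_start, chunk_size, n);
    }
  }
};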