File tree: 2 files changed, +17 -5 lines changed
Original file line number | Diff line number | Diff line change
@@ -194,9 +194,13 @@ static size_t _parseChosenWorkspaceSize() {
194
194
}
195
195
size_t workspace_size = 76 *1024 ; /* Use 76 MB for hipBLASLt */
196
196
#else
197
+ #if defined(FBCODE_CAFFE2)
197
198
size_t workspace_size = 1024 ; /* default size in KiB according to #73328 */
199
+ #else
200
+ // default to CUBLAS_WORKSPACE_CONFIG workspace size
201
+ size_t workspace_size = at::cuda::getChosenWorkspaceSize () / 1024 ;
202
+ #endif
198
203
#endif
199
-
200
204
if (val.has_value ()) {
201
205
try {
202
206
workspace_size = std::stoi (val.value ());
@@ -236,7 +240,12 @@ struct CublasLtWorkspace {
236
240
CublasLtWorkspace () {
237
241
size = _getWorkspaceSize ();
238
242
#ifndef USE_ROCM
239
- static bool unified = c10::utils::check_env (" TORCH_CUBLASLT_UNIFIED_WORKSPACE" ) == true ;
243
+ constexpr auto envvar = " TORCH_CUBLASLT_UNIFIED_WORKSPACE" ;
244
+ #if defined(FBCODE_CAFFE2)
245
+ static bool unified = c10::utils::check_env (envvar) == true ;
246
+ #else
247
+ static bool unified = c10::utils::has_env (envvar) ? c10::utils::check_env (envvar) == true : true ;
248
+ #endif
240
249
if (unified) {
241
250
auto cublasWorkspaceSize = at::cuda::getChosenWorkspaceSize ();
242
251
if (cublasWorkspaceSize < size) {
Original file line number Diff line number Diff line change @@ -127,10 +127,13 @@ size_t parseChosenWorkspaceSize() {
127
127
const bool gfx94_95 = at::detail::getCUDAHooks ().isGPUArch ({" gfx94" , " gfx95" });
128
128
const size_t default_size = gfx94_95 ? 1024 * 128 * 1024 : 1024 * 32 * 1024 ;
129
129
#else
130
- /* :4096:2:16:8 default, 32MiB for Hopper */
130
+ /* :4096:2:16:8 default, 32MiB for Hopper/Blackwell, 12MiB for GeForce Blackwell */
131
131
cudaDeviceProp* properties = at::cuda::getCurrentDeviceProperties ();
132
- const bool sm90 = properties != nullptr && properties->major == 9 && properties->minor == 0 ;
133
- const size_t default_size = sm90 ? 4096 * 8 * 1024 : 4096 * 1024 * 2 + 16 * 1024 * 8 ;
132
+ const bool sm90or100 = properties != nullptr && (properties->major == 9 || properties->major == 10 ) && properties->minor == 0 ;
133
+ const bool sm120 = properties != nullptr && properties->major == 12 && properties->minor == 0 ;
134
+ constexpr size_t sm90or100size = 32768 * 1024 ;
135
+ constexpr size_t sm120size = 12288 * 1024 ;
136
+ const size_t default_size = sm90or100 ? sm90or100size : sm120 ? sm120size : 4096 * 1024 * 2 + 16 * 1024 * 8 ;
134
137
#endif
135
138
136
139
if (val) {
You can’t perform that action at this time.
0 commit comments