commit 16839b58d9b3c3162a67ce5d776b36d4d24e801f Author: mertalev <101130780+mertalev@users.noreply.github.com> Date: Wed Mar 5 11:25:38 2025 -0500 disable algo caching (attributed to @dmnieto in https://github.com/microsoft/onnxruntime/pull/19567) diff --git a/onnxruntime/core/providers/rocm/nn/conv.cc b/onnxruntime/core/providers/rocm/nn/conv.cc index d7f47d07a8..4060a2af52 100644 --- a/onnxruntime/core/providers/rocm/nn/conv.cc +++ b/onnxruntime/core/providers/rocm/nn/conv.cc @@ -127,7 +127,6 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) if (w_dims_changed) { s_.last_w_dims = gsl::make_span(w_dims); - s_.cached_benchmark_fwd_results.clear(); } ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(X->Shape(), W->Shape(), channels_last, channels_last)); @@ -277,35 +276,6 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) HIP_CALL_THROW(hipMalloc(&s_.b_zero, malloc_size)); HIP_CALL_THROW(hipMemsetAsync(s_.b_zero, 0, malloc_size, Stream(context))); } - - if (!s_.cached_benchmark_fwd_results.contains(x_dims_miopen)) { - miopenConvAlgoPerf_t perf; - int algo_count = 1; - const ROCMExecutionProvider* rocm_ep = static_cast(this->Info().GetExecutionProvider()); - static constexpr int num_algos = MIOPEN_CONVOLUTION_FWD_ALGO_COUNT; - size_t max_ws_size = rocm_ep->GetMiopenConvUseMaxWorkspace() ? GetMaxWorkspaceSize(GetMiopenHandle(context), s_, kAllAlgos, num_algos, rocm_ep->GetDeviceId()) - : AlgoSearchWorkspaceSize; - IAllocatorUniquePtr algo_search_workspace = GetTransientScratchBuffer(max_ws_size); - MIOPEN_RETURN_IF_ERROR(miopenFindConvolutionForwardAlgorithm( - GetMiopenHandle(context), - s_.x_tensor, - s_.x_data, - s_.w_desc, - s_.w_data, - s_.conv_desc, - s_.y_tensor, - s_.y_data, - 1, // requestedAlgoCount - &algo_count, // returnedAlgoCount - &perf, - algo_search_workspace.get(), - max_ws_size, - false)); // Do not do exhaustive algo search. - s_.cached_benchmark_fwd_results.insert(x_dims_miopen, {perf.fwd_algo, perf.memory}); - } - const auto& perf = s_.cached_benchmark_fwd_results.at(x_dims_miopen); - s_.fwd_algo = perf.fwd_algo; - s_.workspace_bytes = perf.memory; } else { // set Y s_.Y = context->Output(0, TensorShape(s_.y_dims)); @@ -319,6 +289,31 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) s_.y_data = reinterpret_cast(s_.Y->MutableData()); } } + + miopenConvAlgoPerf_t perf; + int algo_count = 1; + const ROCMExecutionProvider* rocm_ep = static_cast(this->Info().GetExecutionProvider()); + static constexpr int num_algos = MIOPEN_CONVOLUTION_FWD_ALGO_COUNT; + size_t max_ws_size = rocm_ep->GetMiopenConvUseMaxWorkspace() ? GetMaxWorkspaceSize(GetMiopenHandle(context), s_, kAllAlgos, num_algos, rocm_ep->GetDeviceId()) + : AlgoSearchWorkspaceSize; + IAllocatorUniquePtr algo_search_workspace = GetTransientScratchBuffer(max_ws_size); + MIOPEN_RETURN_IF_ERROR(miopenFindConvolutionForwardAlgorithm( + GetMiopenHandle(context), + s_.x_tensor, + s_.x_data, + s_.w_desc, + s_.w_data, + s_.conv_desc, + s_.y_tensor, + s_.y_data, + 1, // requestedAlgoCount + &algo_count, // returnedAlgoCount + &perf, + algo_search_workspace.get(), + max_ws_size, + false)); // Do not do exhaustive algo search. + s_.fwd_algo = perf.fwd_algo; + s_.workspace_bytes = perf.memory; return Status::OK(); } diff --git a/onnxruntime/core/providers/rocm/nn/conv.h b/onnxruntime/core/providers/rocm/nn/conv.h index bc9846203e..d54218f258 100644 --- a/onnxruntime/core/providers/rocm/nn/conv.h +++ b/onnxruntime/core/providers/rocm/nn/conv.h @@ -108,9 +108,6 @@ class lru_unordered_map { list_type lru_list_; }; -// cached miopen descriptors -constexpr size_t MAX_CACHED_ALGO_PERF_RESULTS = 10000; - template struct MiopenConvState { // if x/w dims changed, update algo and miopenTensors @@ -148,9 +145,6 @@ struct MiopenConvState { decltype(AlgoPerfType().memory) memory; }; - lru_unordered_map cached_benchmark_fwd_results{MAX_CACHED_ALGO_PERF_RESULTS}; - lru_unordered_map cached_benchmark_bwd_results{MAX_CACHED_ALGO_PERF_RESULTS}; - // Some properties needed to support asymmetric padded Conv nodes bool post_slicing_required; TensorShapeVector slice_starts; diff --git a/onnxruntime/core/providers/rocm/nn/conv_transpose.cc b/onnxruntime/core/providers/rocm/nn/conv_transpose.cc index 7447113fdf..a662e35b2e 100644 --- a/onnxruntime/core/providers/rocm/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/rocm/nn/conv_transpose.cc @@ -76,7 +76,6 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dy if (w_dims_changed) { s_.last_w_dims = gsl::make_span(w_dims); - s_.cached_benchmark_bwd_results.clear(); } ConvTransposeAttributes::Prepare p; @@ -126,35 +125,29 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dy } y_data = reinterpret_cast(p.Y->MutableData()); - - if (!s_.cached_benchmark_bwd_results.contains(x_dims)) { - IAllocatorUniquePtr algo_search_workspace = GetScratchBuffer(AlgoSearchWorkspaceSize, context->GetComputeStream()); - - miopenConvAlgoPerf_t perf; - int algo_count = 1; - MIOPEN_RETURN_IF_ERROR(miopenFindConvolutionBackwardDataAlgorithm( - GetMiopenHandle(context), - s_.x_tensor, - x_data, - s_.w_desc, - w_data, - s_.conv_desc, - s_.y_tensor, - y_data, - 1, - &algo_count, - &perf, - algo_search_workspace.get(), - AlgoSearchWorkspaceSize, - false)); - s_.cached_benchmark_bwd_results.insert(x_dims, {perf.bwd_data_algo, perf.memory}); - } - - const auto& perf = s_.cached_benchmark_bwd_results.at(x_dims); - s_.bwd_data_algo = perf.bwd_data_algo; - s_.workspace_bytes = perf.memory; } + IAllocatorUniquePtr algo_search_workspace = GetScratchBuffer(AlgoSearchWorkspaceSize, context->GetComputeStream()); + miopenConvAlgoPerf_t perf; + int algo_count = 1; + MIOPEN_RETURN_IF_ERROR(miopenFindConvolutionBackwardDataAlgorithm( + GetMiopenHandle(context), + s_.x_tensor, + x_data, + s_.w_desc, + w_data, + s_.conv_desc, + s_.y_tensor, + y_data, + 1, + &algo_count, + &perf, + algo_search_workspace.get(), + AlgoSearchWorkspaceSize, + false)); + s_.bwd_data_algo = perf.bwd_data_algo; + s_.workspace_bytes = perf.memory; + // The following block will be executed in case there has been no change in the shapes of the // input and the filter compared to the previous run if (!y_data) {