blob: 8a0579e3d9992cdb5daeb285fc2cb5e48bd259b2 [file] [log] [blame]
diff --git a/tensorflow/lite/delegates/gpu/common/gpu_info.cc b/tensorflow/lite/delegates/gpu/common/gpu_info.cc
index a74ca2760c6..7094e54bf33 100644
--- a/tensorflow/lite/delegates/gpu/common/gpu_info.cc
+++ b/tensorflow/lite/delegates/gpu/common/gpu_info.cc
@@ -462,6 +462,14 @@ void GetGpuInfoFromDeviceDescription(const std::string& gpu_description,
std::string lowered = gpu_description;
absl::AsciiStrToLower(&lowered);
gpu_info->vendor = GetGpuVendor(lowered);
+
+ // Because clvk is an OpenCL layer on top of vulkan, it does not react to CL
+ // optimisation as native CL implementation does. For the time being, let's
+ // manage it manually with explicit condition in the code.
+ if (gpu_info->IsApiOpenCl() && gpu_info->opencl_info.IsCLVK()) {
+ gpu_info->vendor = GpuVendor::kUnknown;
+ }
+
if (gpu_info->IsAdreno()) {
gpu_info->adreno_info = AdrenoInfo(lowered);
} else if (gpu_info->IsApple()) {
diff --git a/tensorflow/lite/delegates/gpu/common/gpu_info.h b/tensorflow/lite/delegates/gpu/common/gpu_info.h
index e806a120564..e066a9579ed 100644
--- a/tensorflow/lite/delegates/gpu/common/gpu_info.h
+++ b/tensorflow/lite/delegates/gpu/common/gpu_info.h
@@ -21,6 +21,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
+#include "absl/strings/match.h"
namespace tflite {
namespace gpu {
@@ -351,6 +352,8 @@ struct OpenClInfo {
bool supports_rgba_f32_tex2d = false;
bool IsImage2dFromBufferSupported() const;
+
+ bool IsCLVK() const { return absl::StrContains(platform_version, "clvk");}
};
enum class MetalLanguageVersion {
diff --git a/tensorflow/lite/delegates/gpu/common/tasks/conv_powervr.cc b/tensorflow/lite/delegates/gpu/common/tasks/conv_powervr.cc
index 72979d0764f..330e60d750b 100644
--- a/tensorflow/lite/delegates/gpu/common/tasks/conv_powervr.cc
+++ b/tensorflow/lite/delegates/gpu/common/tasks/conv_powervr.cc
@@ -256,7 +256,9 @@ void ConvPowerVR::GenerateCode(const GpuInfo& gpu_info) {
if (gpu_info.IsMali()) {
compiler_options_.push_back(CompilerOptions::kClFastRelaxedMath);
}
- if (conv_params_.IsPrivateMemBroadcast() && gpu_info.IsCL20OrHigher()) {
+ if (conv_params_.IsPrivateMemBroadcast() &&
+ (gpu_info.IsCL20OrHigher() ||
+ gpu_info.opencl_info.IsCLVK())) {
compiler_options_.push_back(CompilerOptions::kCl20);
}
bool kernel_is_trivial =
@@ -1291,7 +1293,8 @@ ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams(
} else {
conv_params.weights_upload_type = WeightsUploadType::TEXTURES_MEM_X4;
}
- } else if (gpu_info.IsIntel()) {
+ } else if (gpu_info.IsIntel() ||
+ (gpu_info.IsApiOpenCl() && gpu_info.opencl_info.IsCLVK())) {
if (different_weights_for_height) {
work_group_size_ = int3(16, 1, 1);
work_group_launch_order_ = int3(0, 1, 2);
@@ -1304,17 +1307,32 @@ ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams(
}
conv_params.block_size = int4(1, 1, 1, 4);
conv_params.src_depth_loop_size = 1;
- int sub_group_size = 16;
const bool supports_subgroups =
gpu_info.SupportsExtension("cl_khr_subgroups") ||
- gpu_info.SupportsExtension("cl_intel_subgroups");
- if (definition.precision != CalculationsPrecision::F32_F16 &&
- supports_subgroups &&
- gpu_info.SupportsExtension("cl_intel_required_subgroup_size") &&
- gpu_info.SupportsSubGroupWithSize(sub_group_size)) {
- conv_params.weights_upload_type =
- WeightsUploadType::PRIVATE_MEM_SIMD_BROADCAST;
- conv_params.simd_size = sub_group_size;
+ gpu_info.SupportsExtension("cl_intel_subgroups") ||
+ gpu_info.opencl_info.IsCLVK();
+ if (supports_subgroups) {
+ const int kSubGroupSize = 16;
+ const bool supports_subgroup_size_control =
+ gpu_info.SupportsExtension("cl_intel_required_subgroup_size");
+ if (supports_subgroup_size_control &&
+ gpu_info.SupportsSubGroupWithSize(kSubGroupSize)) {
+ conv_params.weights_upload_type =
+ WeightsUploadType::PRIVATE_MEM_SIMD_BROADCAST;
+ conv_params.simd_size = kSubGroupSize;
+ } else if (gpu_info.opencl_info.IsCLVK()) {
+ // It will work because of specific driver using subgroup size 16
+ conv_params.weights_upload_type =
+ WeightsUploadType::PRIVATE_MEM_SIMD_BROADCAST;
+ conv_params.simd_size = 16;
+ } else {
+ // no support of subgroup size control
+ // only smallest subgroup size (8) can be used safely, otherwise
+ // correctness can not be guaranteed
+ // conv_params.weights_upload_type =
+ // WeightsUploadType::PRIVATE_MEM_SIMD_BROADCAST;
+ // conv_params.simd_size = 8;
+ }
} else {
conv_params.weights_upload_type = WeightsUploadType::LOCAL_MEM_BY_THREADS;
}