OccamRazor commited on
Commit
8637c17
·
unverified ·
1 Parent(s): 85caa3f

Refactor validation and enumeration platform checks into functions to clean up ggml_vk_instance_init()

Browse files
Files changed (1) hide show
  1. ggml-vulkan.cpp +62 -37
ggml-vulkan.cpp CHANGED
@@ -1091,7 +1091,10 @@ static void ggml_vk_print_gpu_info(size_t idx) {
1091
  }
1092
  }
1093
 
1094
- static void ggml_vk_instance_init() {
 
 
 
1095
  if (vk_instance_initialized) {
1096
  return;
1097
  }
@@ -1102,54 +1105,40 @@ static void ggml_vk_instance_init() {
1102
  vk::ApplicationInfo app_info{ "ggml-vulkan", 1, nullptr, 0, VK_API_VERSION };
1103
 
1104
  const std::vector<vk::ExtensionProperties> instance_extensions = vk::enumerateInstanceExtensionProperties();
1105
- #ifdef __APPLE__
1106
- bool portability_enumeration_ext = false;
1107
- // Check for portability enumeration extension for MoltenVK support
1108
- for (const auto& properties : instance_extensions) {
1109
- if (strcmp("VK_KHR_portability_enumeration", properties.extensionName) == 0) {
1110
- portability_enumeration_ext = true;
1111
- break;
1112
- }
1113
  }
1114
- if (!portability_enumeration_ext) {
1115
- std::cerr << "ggml_vulkan: WARNING: Instance extension VK_KHR_portability_enumeration not found." << std::endl;
 
1116
  }
1117
- #endif
1118
-
1119
- std::vector<const char*> layers = {
1120
- #ifdef GGML_VULKAN_VALIDATE
1121
- "VK_LAYER_KHRONOS_validation",
1122
- #endif
1123
- };
1124
- std::vector<const char*> extensions = {
1125
- #ifdef GGML_VULKAN_VALIDATE
1126
- "VK_EXT_validation_features",
1127
- #endif
1128
- };
1129
- #ifdef __APPLE__
1130
  if (portability_enumeration_ext) {
1131
  extensions.push_back("VK_KHR_portability_enumeration");
1132
  }
1133
- #endif
1134
  vk::InstanceCreateInfo instance_create_info(vk::InstanceCreateFlags{}, &app_info, layers, extensions);
1135
- #ifdef __APPLE__
1136
  if (portability_enumeration_ext) {
1137
  instance_create_info.flags |= vk::InstanceCreateFlagBits::eEnumeratePortabilityKHR;
1138
  }
1139
- #endif
1140
 
 
 
1141
 
1142
- #ifdef GGML_VULKAN_VALIDATE
1143
- const std::vector<vk::ValidationFeatureEnableEXT> features_enable = { vk::ValidationFeatureEnableEXT::eBestPractices };
1144
- vk::ValidationFeaturesEXT validation_features = {
1145
- features_enable,
1146
- {},
1147
- };
1148
- validation_features.setPNext(nullptr);
1149
- instance_create_info.setPNext(&validation_features);
1150
 
1151
- std::cerr << "ggml_vulkan: Validation layers enabled" << std::endl;
1152
- #endif
1153
  vk_instance.instance = vk::createInstance(instance_create_info);
1154
 
1155
  memset(vk_instance.initialized, 0, sizeof(bool) * GGML_VK_MAX_DEVICES);
@@ -5329,6 +5318,42 @@ GGML_CALL int ggml_backend_vk_reg_devices() {
5329
  return vk_instance.device_indices.size();
5330
  }
5331
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5332
  // checks
5333
 
5334
  #ifdef GGML_VULKAN_CHECK_RESULTS
 
1091
  }
1092
  }
1093
 
1094
+ static bool ggml_vk_instance_validation_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);
1095
+ static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);
1096
+
1097
+ void ggml_vk_instance_init() {
1098
  if (vk_instance_initialized) {
1099
  return;
1100
  }
 
1105
  vk::ApplicationInfo app_info{ "ggml-vulkan", 1, nullptr, 0, VK_API_VERSION };
1106
 
1107
  const std::vector<vk::ExtensionProperties> instance_extensions = vk::enumerateInstanceExtensionProperties();
1108
+ const bool validation_ext = ggml_vk_instance_validation_ext_available(instance_extensions);
1109
+ const bool portability_enumeration_ext = ggml_vk_instance_portability_enumeration_ext_available(instance_extensions);
1110
+
1111
+ std::vector<const char*> layers;
1112
+
1113
+ if (validation_ext) {
1114
+ layers.push_back("VK_LAYER_KHRONOS_validation");
 
1115
  }
1116
+ std::vector<const char*> extensions;
1117
+ if (validation_ext) {
1118
+ extensions.push_back("VK_EXT_validation_features");
1119
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
1120
  if (portability_enumeration_ext) {
1121
  extensions.push_back("VK_KHR_portability_enumeration");
1122
  }
 
1123
  vk::InstanceCreateInfo instance_create_info(vk::InstanceCreateFlags{}, &app_info, layers, extensions);
 
1124
  if (portability_enumeration_ext) {
1125
  instance_create_info.flags |= vk::InstanceCreateFlagBits::eEnumeratePortabilityKHR;
1126
  }
 
1127
 
1128
+ std::vector<vk::ValidationFeatureEnableEXT> features_enable;
1129
+ vk::ValidationFeaturesEXT validation_features;
1130
 
1131
+ if (validation_ext) {
1132
+ features_enable = { vk::ValidationFeatureEnableEXT::eBestPractices };
1133
+ validation_features = {
1134
+ features_enable,
1135
+ {},
1136
+ };
1137
+ validation_features.setPNext(nullptr);
1138
+ instance_create_info.setPNext(&validation_features);
1139
 
1140
+ std::cerr << "ggml_vulkan: Validation layers enabled" << std::endl;
1141
+ }
1142
  vk_instance.instance = vk::createInstance(instance_create_info);
1143
 
1144
  memset(vk_instance.initialized, 0, sizeof(bool) * GGML_VK_MAX_DEVICES);
 
5318
  return vk_instance.device_indices.size();
5319
  }
5320
 
5321
+ // Extension availability
5322
+ static bool ggml_vk_instance_validation_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions) {
5323
+ #ifdef GGML_VULKAN_VALIDATE
5324
+ bool portability_enumeration_ext = false;
5325
+ // Check for portability enumeration extension for MoltenVK support
5326
+ for (const auto& properties : instance_extensions) {
5327
+ if (strcmp("VK_KHR_portability_enumeration", properties.extensionName) == 0) {
5328
+ return true;
5329
+ }
5330
+ }
5331
+ if (!portability_enumeration_ext) {
5332
+ std::cerr << "ggml_vulkan: WARNING: Instance extension VK_KHR_portability_enumeration not found." << std::endl;
5333
+ }
5334
+ #endif
5335
+ return false;
5336
+
5337
+ UNUSED(instance_extensions);
5338
+ }
5339
+ static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions) {
5340
+ #ifdef __APPLE__
5341
+ bool portability_enumeration_ext = false;
5342
+ // Check for portability enumeration extension for MoltenVK support
5343
+ for (const auto& properties : instance_extensions) {
5344
+ if (strcmp("VK_KHR_portability_enumeration", properties.extensionName) == 0) {
5345
+ return true;
5346
+ }
5347
+ }
5348
+ if (!portability_enumeration_ext) {
5349
+ std::cerr << "ggml_vulkan: WARNING: Instance extension VK_KHR_portability_enumeration not found." << std::endl;
5350
+ }
5351
+ #endif
5352
+ return false;
5353
+
5354
+ UNUSED(instance_extensions);
5355
+ }
5356
+
5357
  // checks
5358
 
5359
  #ifdef GGML_VULKAN_CHECK_RESULTS