@@ -71,8 +71,10 @@ struct ggml_backend_kompute_context {
71
71
std::string name;
72
72
std::shared_ptr<vk::DescriptorPool> pool;
73
73
74
+ ggml_backend_buffer_type buft;
75
+
74
76
ggml_backend_kompute_context (int device)
75
- : device(device), name(ggml_kompute_format_name(device)) {}
77
+ : device(device), name(ggml_kompute_format_name(device)) { buft. context = nullptr ; }
76
78
};
77
79
78
80
// FIXME: It would be good to consolidate the kompute manager and the kompute context into one object
@@ -1918,24 +1920,25 @@ static ggml_backend_buffer_type_i ggml_backend_kompute_buffer_type_interface = {
1918
1920
};
1919
1921
1920
1922
// Returns the buffer type owned by the global Kompute context for `device`,
// creating the context and/or the buffer-type context on first use.
//
// @param device  device index, matched against the `index` field of the
//                entries returned by ggml_vk_available_devices_internal(0).
// @return pointer to the buffer type stored inside s_kompute_context, or
//         nullptr when `device` is not among the available devices or when
//         the global context is already bound to a different device.
ggml_backend_buffer_type_t ggml_backend_kompute_buffer_type(int device) {
    if (s_kompute_context) {
        // There is a single global context holding a single buffer type, so
        // a request for another device cannot be served; report failure
        // instead of silently handing back the other device's buffer type.
        if (s_kompute_context->device != device) {
            return nullptr;
        }
        // Already built by an earlier call.
        if (s_kompute_context->buft.context) {
            return &s_kompute_context->buft;
        }
    }

    const auto devices = ggml_vk_available_devices_internal(0);
    for (const auto & dev : devices) {
        if (device != dev.index) {
            continue;
        }

        // Allocate the global context only after the device is known to
        // exist, so an invalid index never leaves a context bound to it.
        if (!s_kompute_context) {
            s_kompute_context = new ggml_backend_kompute_context(device);
        }

        auto * buft = &s_kompute_context->buft;
        buft->iface   = ggml_backend_kompute_buffer_type_interface;
        buft->context = new ggml_backend_kompute_buffer_type_context(
            dev.index,
            dev.bufferAlignment,
            dev.maxAlloc);
        return buft;
    }

    // Unknown device: no usable buffer type exists, so return nullptr rather
    // than a buffer type whose context and interface were never set.
    return nullptr;
}
1940
1943
1941
1944
// backend
@@ -1974,8 +1977,8 @@ static bool ggml_backend_kompute_supports_op(ggml_backend_t backend, const struc
1974
1977
}
1975
1978
1976
1979
// Reports whether `buft` is the buffer type embedded in this backend's
// Kompute context. The check is by address: only the exact
// ggml_backend_buffer_type object stored in the context is accepted.
static bool ggml_backend_kompute_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
    const auto * kompute_ctx = static_cast<const ggml_backend_kompute_context *>(backend->context);
    const ggml_backend_buffer_type * own_buft = &kompute_ctx->buft;
    return buft == own_buft;
}
1980
1983
1981
1984
static struct ggml_backend_i kompute_backend_i = {
@@ -2007,8 +2010,8 @@ static ggml_guid_t ggml_backend_kompute_guid() {
2007
2010
}
2008
2011
2009
2012
ggml_backend_t ggml_backend_kompute_init (int device) {
2010
- GGML_ASSERT ( s_kompute_context == nullptr );
2011
- s_kompute_context = new ggml_backend_kompute_context (device);
2013
+ if (! s_kompute_context)
2014
+ s_kompute_context = new ggml_backend_kompute_context (device);
2012
2015
2013
2016
ggml_backend_t kompute_backend = new ggml_backend {
2014
2017
/* .guid = */ ggml_backend_kompute_guid (),
0 commit comments