Skip to content

Commit 18f1571

Browse files
[NPUW] Introduce quantized vocab handling (#30553)
1 parent 8c7b54e commit 18f1571

File tree

14 files changed

+662
-6
lines changed

14 files changed

+662
-6
lines changed

src/plugins/intel_npu/src/al/include/intel_npu/config/npuw.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ DEFINE_OPT(NPUW_DQ_FULL, bool, true, npuw::partitioning::dyn_quant_full, RunTime
8282
DEFINE_OPT(NPUW_PMM, std::string, "2", npuw::partitioning::par_matmul_merge_dims, RunTime);
8383
DEFINE_OPT(NPUW_SLICE_OUT, bool, false, npuw::partitioning::slice_out, RunTime);
8484
DEFINE_OPT(NPUW_HOST_GATHER, bool, true, npuw::partitioning::host_gather, RunTime);
85+
DEFINE_OPT(NPUW_HOST_GATHER_QUANT, bool, false, npuw::partitioning::gather_quant, RunTime);
8586
DEFINE_OPT(NPUW_SPATIAL, bool, false, npuw::partitioning::spatial, RunTime);
8687
DEFINE_OPT(NPUW_F16IC, bool, true, npuw::partitioning::f16_interconnect, RunTime);
8788
DEFINE_OPT(NPUW_SPATIAL_NWAY, std::size_t, 128, npuw::partitioning::spatial_nway, RunTime);

src/plugins/intel_npu/src/al/include/intel_npu/npuw_private_properties.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,14 @@ static constexpr ov::Property<bool> f16_interconnect{"NPUW_F16IC"};
250250
*/
251251
static constexpr ov::Property<bool> host_gather{"NPUW_HOST_GATHER"};
252252

253+
/**
254+
* @brief
255+
* Type: boolean
256+
* When applicable, do embedding gather on host but leave it quantized.
257+
* Default value: false.
258+
*/
259+
static constexpr ov::Property<bool> gather_quant{"NPUW_HOST_GATHER_QUANT"};
260+
253261
/**
254262
* @brief
255263
* Type: std::string.

src/plugins/intel_npu/src/al/src/config/npuw.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ void intel_npu::registerNPUWOptions(OptionsDesc& desc) {
3434
desc.add<NPUW_SPATIAL_NWAY>();
3535
desc.add<NPUW_SPATIAL_DYN>();
3636
desc.add<NPUW_HOST_GATHER>();
37+
desc.add<NPUW_HOST_GATHER_QUANT>();
3738
desc.add<NPUW_F16IC>();
3839
desc.add<NPUW_DCOFF_TYPE>();
3940
desc.add<NPUW_DCOFF_SCALE>();

src/plugins/intel_npu/src/plugin/npuw/base_sync_infer_request.cpp

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -465,12 +465,78 @@ void ov::npuw::IBaseInferRequest::bind_global_params(std::size_t idx, RqPtr requ
465465
const auto& vocab = comp_model_desc.closure[comp_model_desc.host_gather.src_idx - comp_model_desc.param_base];
466466
const auto& lport = comp_model_desc.compiled_model->inputs()[comp_model_desc.host_gather.idx_idx];
467467
const auto lookup = request->get_tensor(lport);
468+
468469
ov::npuw::util::gather(ov::get_tensor_impl(vocab), lookup, gather);
469470
}
470471

472+
// Run host-side quantized gather, if required
473+
handle_quant_host_gather(idx, request);
474+
471475
LOG_DEBUG("Done");
472476
}
473477

478+
// Performs the host-side "quantized gather" step for subgraph `idx`, if that
// subgraph was configured with a quant_unpack_gather mapping: rows of the
// still-quantized vocab tensors are gathered on the host using the lookup
// indices, then unpacked (dequantized) directly into the destination input
// tensor of the compiled submodel. No-op when dst_idx == -1.
void ov::npuw::IBaseInferRequest::handle_quant_host_gather(std::size_t idx, RqPtr request) {
    auto& comp_model_desc = m_npuw_model->m_compiled_submodels[idx];

    // dst_idx == -1 means this subgraph carries no quantized-gather mapping
    if (comp_model_desc.quant_unpack_gather.dst_idx != -1) {
        // A valid mapping always has the lookup-ids input and the weight vocab input
        NPUW_ASSERT(comp_model_desc.quant_unpack_gather.idx_idx != -1 &&
                    comp_model_desc.quant_unpack_gather.src_w_idx != -1);

        // Lookup (token ids) input tensor
        const auto& lport = comp_model_desc.compiled_model->inputs()[comp_model_desc.quant_unpack_gather.idx_idx];
        const auto& lookup = request->get_tensor(lport);

        // Destination input tensor - receives the gathered, dequantized rows
        const auto& gport = comp_model_desc.compiled_model->inputs()[comp_model_desc.quant_unpack_gather.dst_idx];
        const auto& gather = request->get_tensor(gport);

        // Quantized weight vocab input tensor
        const auto& wport = comp_model_desc.compiled_model->inputs()[comp_model_desc.quant_unpack_gather.src_w_idx];
        const auto& vocabw = request->get_tensor(wport);

        // NOTE(review): assumes the ids tensor is at least 2D with the sequence
        // length at dim 1 (shape like [1, seq_len]) - confirm against the
        // HostGatherQuant* matchers that produce this mapping
        auto ids_shape = lookup->get_shape();

        // Shape of a gathered tensor: [1, n_ids, row_elems], where row_elems
        // flattens the trailing dims of a possibly group-quantized (3D) vocab
        auto get_gathered_shape = [&ids_shape](const ov::Shape& shape) {
            return ov::Shape{1, ids_shape[1], shape.size() == 3 ? shape[1] * shape[2] : shape[1]};
        };

        ov::Tensor gatherw(vocabw->get_element_type(), get_gathered_shape(vocabw->get_shape()));
        // Gather weight
        ov::npuw::util::gather(vocabw, lookup, ov::get_tensor_impl(gatherw));

        if (comp_model_desc.quant_unpack_gather.src_z_idx != -1 &&
            comp_model_desc.quant_unpack_gather.src_s_idx != -1) {
            // Asymmetric case: weights + zero points + scales
            const auto& zport = comp_model_desc.compiled_model->inputs()[comp_model_desc.quant_unpack_gather.src_z_idx];
            const auto& vocabz = request->get_tensor(zport);

            const auto& sport = comp_model_desc.compiled_model->inputs()[comp_model_desc.quant_unpack_gather.src_s_idx];
            const auto& vocabs = request->get_tensor(sport);

            ov::Tensor gatherz(vocabz->get_element_type(), get_gathered_shape(vocabz->get_shape()));
            ov::Tensor gathers(vocabs->get_element_type(), get_gathered_shape(vocabs->get_shape()));
            // Gather first
            ov::npuw::util::gather(vocabz, lookup, ov::get_tensor_impl(gatherz));
            ov::npuw::util::gather(vocabs, lookup, ov::get_tensor_impl(gathers));

            // Then unpack
            ov::npuw::util::unpack(ov::get_tensor_impl(gatherw),
                                   ov::get_tensor_impl(gatherz),
                                   ov::get_tensor_impl(gathers),
                                   gather);
        } else if (comp_model_desc.quant_unpack_gather.src_s_idx != -1) {
            // Symmetric case: weights + scales only
            const auto& sport = comp_model_desc.compiled_model->inputs()[comp_model_desc.quant_unpack_gather.src_s_idx];
            const auto& vocabs = request->get_tensor(sport);

            ov::Tensor gathers(vocabs->get_element_type(), get_gathered_shape(vocabs->get_shape()));
            // Gather first
            ov::npuw::util::gather(vocabs, lookup, ov::get_tensor_impl(gathers));

            // Then unpack
            ov::npuw::util::unpack(ov::get_tensor_impl(gatherw), ov::get_tensor_impl(gathers), gather);
        } else {
            // No zero points / scales in the mapping:
            // Already gathered above - just unpack
            ov::npuw::util::unpack(ov::get_tensor_impl(gatherw), gather);
        }
    }
}
539+
474540
void ov::npuw::IBaseInferRequest::bind_global_results(std::size_t idx, RqPtr request) {
475541
LOG_DEBUG("Binding results for Subgraph[" << idx << "]");
476542
LOG_BLOCK();

src/plugins/intel_npu/src/plugin/npuw/base_sync_infer_request.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ class IBaseInferRequest : public ov::ISyncInferRequest {
149149
void unpack_closure(std::size_t idx, RqPtr request);
150150
virtual void bind_global_params(std::size_t idx, RqPtr request);
151151
virtual void bind_global_results(std::size_t idx, RqPtr request);
152+
void handle_quant_host_gather(std::size_t idx, RqPtr request);
152153

153154
void dump_input_tensors(std::size_t idx);
154155
void dump_output_tensors(std::size_t idx);

src/plugins/intel_npu/src/plugin/npuw/compiled_model.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,7 @@ ov::npuw::CompiledModel::CompiledModel(const std::shared_ptr<ov::Model>& model,
367367
LOG_INFO("Subgraph[" << id << "] is a function call to [" << compiled_fcn_iter->second << "]");
368368
}
369369
m_compiled_submodels[id].host_gather = subgraph._host_gather;
370+
m_compiled_submodels[id].quant_unpack_gather = subgraph._quant_unpack_gather;
370371
m_compiled_submodels[id].param_base = fcn_template._param_offset;
371372
m_compiled_submodels[id].closure = subgraph._closure;
372373
m_compiled_submodels[id].lazy_closure = subgraph._lazy_closure;
@@ -541,6 +542,12 @@ void ov::npuw::CompiledModel::CompiledModelDesc::serialize(std::ostream& stream,
541542
write(stream, host_gather.src_idx);
542543
write(stream, host_gather.idx_idx);
543544

545+
write(stream, quant_unpack_gather.dst_idx);
546+
write(stream, quant_unpack_gather.src_w_idx);
547+
write(stream, quant_unpack_gather.src_z_idx);
548+
write(stream, quant_unpack_gather.src_s_idx);
549+
write(stream, quant_unpack_gather.idx_idx);
550+
544551
write(stream, spatial);
545552

546553
write(stream, is_remote);
@@ -609,6 +616,12 @@ void ov::npuw::CompiledModel::CompiledModelDesc::deserialize(std::istream& strea
609616
read(stream, host_gather.src_idx);
610617
read(stream, host_gather.idx_idx);
611618

619+
read(stream, quant_unpack_gather.dst_idx);
620+
read(stream, quant_unpack_gather.src_w_idx);
621+
read(stream, quant_unpack_gather.src_z_idx);
622+
read(stream, quant_unpack_gather.src_s_idx);
623+
read(stream, quant_unpack_gather.idx_idx);
624+
612625
read(stream, spatial);
613626

614627
read(stream, is_remote);
@@ -1699,6 +1712,7 @@ void ov::npuw::CompiledModel::implement_properties() {
16991712
BIND(npuw::partitioning::spatial_nway, NPUW_SPATIAL_NWAY),
17001713
BIND(npuw::partitioning::spatial_dyn, NPUW_SPATIAL_DYN),
17011714
BIND(npuw::partitioning::host_gather, NPUW_HOST_GATHER),
1715+
BIND(npuw::partitioning::gather_quant, NPUW_HOST_GATHER_QUANT),
17021716
BIND(npuw::partitioning::funcall_for_all, NPUW_FUNCALL_FOR_ALL),
17031717
BIND(npuw::partitioning::f16_interconnect, NPUW_F16IC),
17041718
BIND(npuw::partitioning::dcoff_type, NPUW_DCOFF_TYPE),

src/plugins/intel_npu/src/plugin/npuw/compiled_model.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ class CompiledModel : public ov::npuw::ICompiledModel {
152152
std::optional<std::size_t> replaced_by;
153153

154154
Subgraph::Gather host_gather;
155+
Subgraph::QuantUnpackGather quant_unpack_gather;
155156
std::optional<ov::npuw::compiled::Spatial> spatial;
156157

157158
// FIXME: This is a 1:1 copy of the ov::npuw::Subgraph structure

src/plugins/intel_npu/src/plugin/npuw/partitioning/partitioning.cpp

Lines changed: 120 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,7 @@ class Partitioner {
322322
void saveTinyConstants(const std::string& func_name);
323323
void saveScaleFactors(const std::string& func_name);
324324
void saveRepeatedConstants(const std::string& func_name);
325+
void saveTailDictConstants(const std::string& func_name);
325326
void matchParameters(const std::string& func_name);
326327
void matchResults(const std::string& func_name);
327328
void createFunction(const std::string& func_name);
@@ -1430,6 +1431,42 @@ void Partitioner::saveRepeatedConstants(const std::string& func_name) {
14301431
}
14311432
}
14321433

1434+
void Partitioner::saveTailDictConstants(const std::string& func_name) {
1435+
if (!cfg.get<::intel_npu::NPUW_HOST_GATHER_QUANT>()) {
1436+
// No need to preserve as constants
1437+
return;
1438+
}
1439+
1440+
// Depending on the config we might want to save vocab in the tail subgraph as a Constant.
1441+
auto& func_group = all_functions.at(func_name);
1442+
auto& subgr_group = func_group.refs;
1443+
1444+
if (subgr_group.size() > 1) {
1445+
// Skip the repeated block
1446+
return;
1447+
}
1448+
1449+
LOG_VERB("Trying to preserve some (tail) constants for " << func_name << " in model " << model->get_friendly_name()
1450+
<< "...");
1451+
LOG_BLOCK();
1452+
1453+
auto& model_group = func_group.mdls;
1454+
1455+
using CPtr = std::shared_ptr<ov::op::v0::Constant>;
1456+
std::vector<CPtr> to_keep;
1457+
1458+
ov::pass::GraphRewrite rewr;
1459+
rewr.add_matcher<ov::npuw::patterns::opt::PreserveConstDictMatMulCWu>(std::ref(to_keep));
1460+
rewr.add_matcher<ov::npuw::patterns::opt::PreserveConstDictMatMulCWf8>(std::ref(to_keep));
1461+
rewr.run_on_model(model_group.front());
1462+
1463+
for (auto&& const_to_keep : to_keep) {
1464+
LOG_DEBUG("[KEEP] " << const_to_keep);
1465+
func_group.consts_to_keep.insert(const_to_keep);
1466+
}
1467+
LOG_VERB("Done");
1468+
}
1469+
14331470
void Partitioner::matchParameters(const std::string& func_name) {
14341471
LOG_VERB("Matching parameters for function " << func_name << " in model " << model->get_friendly_name() << "...");
14351472
LOG_BLOCK();
@@ -1874,12 +1911,20 @@ void Partitioner::optimize(const std::string& func_name) {
18741911
ctx.is_spatial = f._spatial.has_value();
18751912
ctx.pmm_dims = cfg.get<::intel_npu::NPUW_PMM>();
18761913

1914+
if (cfg.get<::intel_npu::NPUW_HOST_GATHER_QUANT>() && cfg.get<::intel_npu::NPUW_HOST_GATHER>()) {
1915+
NPUW_ASSERT(false && "Conflicting configuration: NPUW_HOST_GATHER and NPUW_HOST_GATHER_QUANT should not be "
1916+
"enabled together!");
1917+
}
1918+
18771919
// Run Head/Tail passes
18781920
ov::pass::GraphRewrite rewr;
1879-
rewr.add_matcher<ov::npuw::patterns::opt::DQUnpackDictGatheru>(std::ref(ctx));
1880-
rewr.add_matcher<ov::npuw::patterns::opt::DQUnpackDictGatherGQi>(std::ref(ctx));
1881-
rewr.add_matcher<ov::npuw::patterns::opt::DQUnpackDictMatMulCWu>(std::ref(ctx));
1882-
rewr.add_matcher<ov::npuw::patterns::opt::DQUnpackDictMatMulCWf8>(std::ref(ctx));
1921+
if (!cfg.get<::intel_npu::NPUW_HOST_GATHER_QUANT>()) {
1922+
rewr.add_matcher<ov::npuw::patterns::opt::DQUnpackDictGatheru>(std::ref(ctx));
1923+
rewr.add_matcher<ov::npuw::patterns::opt::DQUnpackDictGatherGQi>(std::ref(ctx));
1924+
rewr.add_matcher<ov::npuw::patterns::opt::DQUnpackDictMatMulCWu>(std::ref(ctx));
1925+
rewr.add_matcher<ov::npuw::patterns::opt::DQUnpackDictMatMulCWf8>(std::ref(ctx));
1926+
}
1927+
18831928
// NB: This pass is disabled for reason! It doesn't make things better
18841929
// rewr.add_matcher<ov::npuw::patterns::opt::DQUnpackDictMatMulGQi>(std::ref(ctx));
18851930
rewr.add_matcher<ov::npuw::patterns::opt::CompressDictMatMulf32>(std::ref(ctx));
@@ -1888,6 +1933,29 @@ void Partitioner::optimize(const std::string& func_name) {
18881933
rewr.add_matcher<ov::npuw::patterns::opt::ConvToMatmul>(std::ref(ctx));
18891934
rewr.run_on_model(f._model);
18901935

1936+
// Quantized Gather + Unpack on host in the runtime
1937+
if (cfg.get<::intel_npu::NPUW_HOST_GATHER_QUANT>()) {
1938+
// FIXME: since we are running it after lifted Gather,
1939+
// we need to first try to match Asymm or Symm patterns.
1940+
// Otherwise smaller HostGatherQuant might be matched first and break
1941+
// the quantization logic.
1942+
{
1943+
ov::pass::GraphRewrite rewr2;
1944+
rewr2.add_matcher<ov::npuw::patterns::opt::HostGatherQuantAsymm>(std::ref(ctx));
1945+
rewr2.run_on_model(f._model);
1946+
}
1947+
{
1948+
ov::pass::GraphRewrite rewr2;
1949+
rewr2.add_matcher<ov::npuw::patterns::opt::HostGatherQuantSymm>(std::ref(ctx));
1950+
rewr2.run_on_model(f._model);
1951+
}
1952+
{
1953+
ov::pass::GraphRewrite rewr2;
1954+
rewr2.add_matcher<ov::npuw::patterns::opt::HostGatherQuant>(std::ref(ctx));
1955+
rewr2.run_on_model(f._model);
1956+
}
1957+
}
1958+
18911959
// Move Gather to host, if required
18921960
if (cfg.get<::intel_npu::NPUW_HOST_GATHER>()) {
18931961
ov::pass::GraphRewrite rewr2;
@@ -1992,6 +2060,30 @@ void Partitioner::optimize(const std::string& func_name) {
19922060
}
19932061
}
19942062

2063+
// Host-side quantized gather, pt 1. Add new parameters first
2064+
if (ctx.params_to_quant_gather_unpack) {
2065+
auto& params_to_quant_gather_unpack = *ctx.params_to_quant_gather_unpack;
2066+
for (const auto& param_new_and_unpack : params_to_quant_gather_unpack.params_to_runtime_unpack_gather) {
2067+
// New input in the graph
2068+
new_params.push_back(param_new_and_unpack.first);
2069+
// Note: don't remove w, z and s params here to keep them shared with the quant vocab in tail
2070+
for (auto&& funcall : func_group.refs) {
2071+
auto new_elem_type = param_new_and_unpack.first->get_element_type();
2072+
const auto& new_shape = param_new_and_unpack.first->get_shape();
2073+
// Note: no allocation needed for this tensor - set to _closure and dummy in _lazy_closure
2074+
// FIXME: It turns out this tensor will be completely unused.
2075+
// It will just sit in the memory to do nothing.
2076+
// Most likely it may stay empty since we need a 1:1 matching between
2077+
// closure tensors and parameters (minus base).
2078+
// Based on our logic (when tensors get transferred from lazy tensors via bank
2079+
// to the closure), this tensor should be non-empty to avoid this process.
2080+
funcall.get()._closure.push_back(ov::Tensor(new_elem_type, new_shape));
2081+
funcall.get()._lazy_closure.push_back(LazyTensor());
2082+
funcall.get()._is_lazy_unpack.push_back(false);
2083+
}
2084+
}
2085+
}
2086+
19952087
// Add all new parameters introduced by this change
19962088
f._model->add_parameters(new_params);
19972089

@@ -2031,6 +2123,29 @@ void Partitioner::optimize(const std::string& func_name) {
20312123
}
20322124
}
20332125

2126+
// Host-side quantized gather, pt. 2: Write the gather mappings to funcall
2127+
if (ctx.params_to_quant_gather_unpack) {
2128+
auto& params_to_quant_gather_unpack = *ctx.params_to_quant_gather_unpack;
2129+
for (const auto& param_new_and_unpack_gather :
2130+
params_to_quant_gather_unpack.params_to_runtime_unpack_gather) {
2131+
// New param in the graph
2132+
auto gather_dst_id = f._model->get_parameter_index(param_new_and_unpack_gather.first);
2133+
// Orig params to gather from
2134+
auto gather_w_id = f._model->get_parameter_index(param_new_and_unpack_gather.second.w);
2135+
auto gather_z_id = f._model->get_parameter_index(param_new_and_unpack_gather.second.z);
2136+
auto gather_s_id = f._model->get_parameter_index(param_new_and_unpack_gather.second.s);
2137+
// Original pids
2138+
auto gather_idx_id = f._model->get_parameter_index(params_to_quant_gather_unpack.pids);
2139+
for (auto&& funcall : func_group.refs) {
2140+
funcall.get()._quant_unpack_gather = ov::npuw::Subgraph::QuantUnpackGather{gather_dst_id,
2141+
gather_w_id,
2142+
gather_z_id,
2143+
gather_s_id,
2144+
gather_idx_id};
2145+
}
2146+
}
2147+
}
2148+
20342149
// FIXME: workaround
20352150
// Set lazy unpack indexes not to be unpacked in DCOFF
20362151
for (auto&& fref : func_group.refs) {
@@ -2344,6 +2459,7 @@ ov::npuw::Partitioning ov::npuw::getPartitioning(const std::shared_ptr<ov::Model
23442459
p.propagateConvertsOut(func_group);
23452460
p.sanityCheck(func_group);
23462461
p.saveRepeatedConstants(func_group);
2462+
p.saveTailDictConstants(func_group);
23472463
p.matchParameters(func_group);
23482464
p.matchResults(func_group);
23492465
p.matchRepeatedSubgraphs(func_group);

src/plugins/intel_npu/src/plugin/npuw/partitioning/partitioning.hpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,17 @@ struct Subgraph {
5959
};
6060
Gather _host_gather;
6161

62+
// Describes a host-side "gather + unpack" of a quantized vocab: at run time,
// rows are gathered from the quantized weight (and, when present, zero-point
// and scale) inputs using the lookup ids, dequantized, and written into the
// destination input. Each field is an index into the compiled submodel's
// input list; -1 means "not present".
struct QuantUnpackGather {
    int64_t dst_idx = -1;  // destination input - receives the dequantized gathered rows

    int64_t src_w_idx = -1;  // quantized weight vocab input (required when dst_idx is set)
    int64_t src_z_idx = -1;  // zero-points input (asymmetric quantization only)
    int64_t src_s_idx = -1;  // scales input (absent for plain unpack)

    int64_t idx_idx = -1;  // lookup (token ids) input (required when dst_idx is set)
};
QuantUnpackGather _quant_unpack_gather;
72+
6273
using Ref = std::reference_wrapper<Subgraph>;
6374
};
6475

0 commit comments

Comments
 (0)