Skip to content

Commit 991fbc8

Browse files
authored
refactor: InferenceDomainMapValuesのインスタンスをマクロで作る (VOICEVOX#852)
VOICEVOX#737 に向け。また VOICEVOX#851 の後にdecode.onnx入りのVVMに対応するときも同様に 役に立つはず。
1 parent f2e6b60 commit 991fbc8

File tree

11 files changed

+147
-81
lines changed

11 files changed

+147
-81
lines changed

Cargo.lock

Lines changed: 14 additions & 68 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ const_format = "0.2.33"
2727
cstr = "0.2.12" # https://github.com/dtolnay/syn/issues/1502
2828
derive-getters = "0.2.0"
2929
derive-new = "0.5.9"
30+
derive-syn-parse = "0.2.0"
3031
derive_more = "0.99.17"
3132
duct = "0.13.7"
3233
duplicate = "1.0.0"

crates/voicevox_core/src/infer.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ mod model_file;
33
pub(crate) mod runtimes;
44
pub(crate) mod session_set;
55

6-
use std::{borrow::Cow, collections::BTreeSet, fmt::Debug};
6+
use std::{borrow::Cow, collections::BTreeSet, fmt::Debug, ops::Index, sync::Arc};
77

88
use derive_new::new;
99
use duplicate::duplicate_item;
@@ -51,6 +51,7 @@ pub(crate) trait InferenceRuntime: 'static {
5151
/// 共に扱われるべき推論操作の集合を示す。
5252
pub(crate) trait InferenceDomain: Sized {
5353
type Operation: InferenceOperation;
54+
type Manifest: Index<Self::Operation, Output = Arc<str>>;
5455

5556
/// 対応する`StyleType`。
5657
///

crates/voicevox_core/src/infer/domains.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,3 +63,12 @@ pub(crate) trait InferenceDomainMapValues {
6363
impl<T> InferenceDomainMapValues for (T,) {
6464
type Talk = T;
6565
}
66+
67+
macro_rules! inference_domain_map_values {
68+
(for<$arg:ident> $body:ty) => {
69+
(::macros::substitute_type!(
70+
$body where $arg = crate::infer::domains::TalkDomain as crate::infer::InferenceDomain
71+
),)
72+
};
73+
}
74+
pub(crate) use inference_domain_map_values;

crates/voicevox_core/src/infer/domains/talk.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use enum_map::Enum;
44
use macros::{InferenceInputSignature, InferenceOperation, InferenceOutputSignature};
55
use ndarray::{Array0, Array1, Array2};
66

7-
use crate::StyleType;
7+
use crate::{manifest::TalkManifest, StyleType};
88

99
use super::super::{
1010
InferenceDomain, InferenceInputSignature as _, InferenceOutputSignature as _, OutputTensor,
@@ -14,6 +14,7 @@ pub(crate) enum TalkDomain {}
1414

1515
impl InferenceDomain for TalkDomain {
1616
type Operation = TalkOperation;
17+
type Manifest = TalkManifest;
1718

1819
fn style_types() -> &'static BTreeSet<StyleType> {
1920
static STYLE_TYPES: LazyLock<BTreeSet<StyleType>> =

crates/voicevox_core/src/manifest.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use serde::{de, Deserialize, Deserializer, Serialize};
1212
use serde_with::{serde_as, DisplayFromStr};
1313

1414
use crate::{
15-
infer::domains::{InferenceDomainMap, TalkOperation},
15+
infer::domains::{inference_domain_map_values, InferenceDomainMap, TalkOperation},
1616
StyleId, VoiceModelId,
1717
};
1818

@@ -79,7 +79,7 @@ pub struct Manifest {
7979
domains: InferenceDomainMap<ManifestDomains>,
8080
}
8181

82-
pub(crate) type ManifestDomains = (Option<TalkManifest>,);
82+
pub(crate) type ManifestDomains = inference_domain_map_values!(for<D> Option<D::Manifest>);
8383

8484
#[derive(Deserialize, IndexForFields)]
8585
#[cfg_attr(test, derive(Default))]

crates/voicevox_core/src/status.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use itertools::iproduct;
99
use crate::{
1010
error::{ErrorRepr, LoadModelError, LoadModelErrorKind, LoadModelResult},
1111
infer::{
12-
domains::{InferenceDomainMap, TalkDomain, TalkOperation},
12+
domains::{inference_domain_map_values, InferenceDomainMap, TalkDomain},
1313
session_set::{InferenceSessionCell, InferenceSessionSet},
1414
InferenceDomain, InferenceInputSignature, InferenceRuntime, InferenceSessionOptions,
1515
InferenceSignature,
@@ -338,10 +338,11 @@ impl InferenceDomainMap<ModelBytesWithInnerVoiceIdsByDomain> {
338338
}
339339
}
340340

341-
type SessionOptionsByDomain = (EnumMap<TalkOperation, InferenceSessionOptions>,);
341+
type SessionOptionsByDomain =
342+
inference_domain_map_values!(for<D> EnumMap<D::Operation, InferenceSessionOptions>);
342343

343344
type SessionSetsWithInnerVoiceIdsByDomain<R> =
344-
(Option<(StyleIdToInnerVoiceId, InferenceSessionSet<R, TalkDomain>)>,);
345+
inference_domain_map_values!(for<D> Option<(StyleIdToInnerVoiceId, InferenceSessionSet<R, D>)>);
345346

346347
#[cfg(test)]
347348
mod tests {

crates/voicevox_core/src/voice_model.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,10 @@ use crate::{
2323
asyncs::{Async, Mutex as _},
2424
error::{LoadModelError, LoadModelErrorKind, LoadModelResult},
2525
infer::{
26-
domains::{InferenceDomainMap, TalkDomain, TalkOperation},
26+
domains::{inference_domain_map_values, InferenceDomainMap, TalkDomain, TalkOperation},
2727
InferenceDomain,
2828
},
29-
manifest::{Manifest, ManifestDomains, StyleIdToInnerVoiceId, TalkManifest},
29+
manifest::{Manifest, ManifestDomains, StyleIdToInnerVoiceId},
3030
SpeakerMeta, StyleMeta, StyleType, VoiceModelMeta,
3131
};
3232

@@ -35,8 +35,9 @@ use crate::{
3535
/// [`VoiceModelId`]: VoiceModelId
3636
pub type RawVoiceModelId = Uuid;
3737

38-
pub(crate) type ModelBytesWithInnerVoiceIdsByDomain =
39-
(Option<(StyleIdToInnerVoiceId, EnumMap<TalkOperation, Vec<u8>>)>,);
38+
pub(crate) type ModelBytesWithInnerVoiceIdsByDomain = inference_domain_map_values!(
39+
for<D> Option<(StyleIdToInnerVoiceId, EnumMap<D::Operation, Vec<u8>>)>
40+
);
4041

4142
/// 音声モデルID。
4243
#[derive(
@@ -251,7 +252,7 @@ impl<A: Async> Inner<A> {
251252
}
252253

253254
type InferenceModelEntries<'manifest> =
254-
(Option<InferenceModelEntry<TalkDomain, &'manifest TalkManifest>>,);
255+
inference_domain_map_values!(for<D> Option<InferenceModelEntry<D, &'manifest D::Manifest>>);
255256

256257
struct InferenceModelEntry<D: InferenceDomain, M> {
257258
indices: EnumMap<D::Operation, usize>,

crates/voicevox_core_macros/Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,11 @@ name = "macros"
1010
proc-macro = true
1111

1212
[dependencies]
13+
derive-syn-parse.workspace = true
1314
indexmap.workspace = true
1415
proc-macro2.workspace = true
1516
quote.workspace = true
16-
syn = { workspace = true, features = ["extra-traits", "full"] }
17+
syn = { workspace = true, features = ["extra-traits", "full", "visit-mut"] }
1718

1819
[lints.rust]
1920
unsafe_code = "forbid"
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
use derive_syn_parse::Parse;
2+
use quote::ToTokens as _;
3+
use syn::{
4+
parse_quote,
5+
visit_mut::{self, VisitMut},
6+
Path, PathArguments, PathSegment, Token, Type, TypePath,
7+
};
8+
9+
pub(crate) fn substitute_type(input: Substitution) -> syn::Result<proc_macro2::TokenStream> {
10+
let Substitution {
11+
mut body,
12+
arg,
13+
replacement,
14+
replacement_as,
15+
..
16+
} = input;
17+
18+
Substitute {
19+
arg,
20+
replacement,
21+
replacement_as,
22+
}
23+
.visit_type_mut(&mut body);
24+
25+
return Ok(body.to_token_stream());
26+
27+
struct Substitute {
28+
arg: syn::Ident,
29+
replacement: Path,
30+
replacement_as: Path,
31+
}
32+
33+
impl VisitMut for Substitute {
34+
fn visit_type_mut(&mut self, i: &mut Type) {
35+
visit_mut::visit_type_mut(self, i);
36+
37+
let Type::Path(TypePath {
38+
qself: None,
39+
path:
40+
Path {
41+
leading_colon: None,
42+
segments,
43+
},
44+
}) = i
45+
else {
46+
return;
47+
};
48+
49+
match &mut *segments.iter_mut().collect::<Vec<_>>() {
50+
[PathSegment {
51+
ident,
52+
arguments: PathArguments::None,
53+
}] if *ident == self.arg => {
54+
let replacement = self.replacement.clone();
55+
*i = parse_quote!(#replacement);
56+
}
57+
[PathSegment {
58+
ident: ident1,
59+
arguments: PathArguments::None,
60+
}, seg]
61+
if *ident1 == self.arg =>
62+
{
63+
let replacement = self.replacement.clone();
64+
let replacement_as = self.replacement_as.clone();
65+
*i = parse_quote!(<#replacement as #replacement_as>::#seg);
66+
}
67+
_ => {}
68+
}
69+
}
70+
}
71+
}
72+
73+
/// `$body:ty where $arg:ident = $replacement:path as $replacement_as:path`
74+
#[derive(Parse)]
75+
pub(crate) struct Substitution {
76+
body: Type,
77+
_where_token: Token![where],
78+
arg: syn::Ident,
79+
_eq_token: Token![=],
80+
replacement: Path,
81+
_as_token: Token![as],
82+
replacement_as: Path,
83+
}

0 commit comments

Comments
 (0)