Skip to content

Commit ef4bd8a

Browse files
authored
Sigma transitions (#223)
* Add SigmaMatcher * First version working * Clean * Test sigma matcher * Use SymbolTable in Fst's Display * Save work * Save work * Save work * Save work * FFI sigma * Fmt python * Implement missing CDrop * Fix clippy * Clippy
1 parent 17316c8 commit ef4bd8a

File tree

9 files changed

+905
-138
lines changed

9 files changed

+905
-138
lines changed

rustfst-ffi/src/algorithms/compose.rs

Lines changed: 132 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@ use anyhow::{anyhow, Result};
22

33
use super::EnumConversionError;
44
use crate::fst::CFst;
5-
use crate::{get, wrap, RUSTFST_FFI_RESULT};
5+
use crate::{get, wrap, CLabel, RUSTFST_FFI_RESULT};
66

77
use ffi_convert::*;
8+
use rustfst::algorithms::compose::matchers::MatcherRewriteMode;
89
use rustfst::algorithms::compose::{
9-
compose, compose_with_config, ComposeConfig, ComposeFilterEnum,
10+
compose, compose_with_config, ComposeConfig, ComposeFilterEnum, MatcherConfig,
11+
SigmaMatcherConfig,
1012
};
1113
use rustfst::fst_impls::VectorFst;
1214
use rustfst::semirings::TropicalWeight;
@@ -50,21 +52,149 @@ impl CReprOf<ComposeFilterEnum> for CComposeFilterEnum {
5052
}
5153
}
5254

55+
#[derive(RawPointerConverter, Debug, Clone)]
56+
pub struct CMatcherRewriteMode(pub(crate) usize);
57+
58+
impl AsRust<MatcherRewriteMode> for CMatcherRewriteMode {
59+
fn as_rust(&self) -> Result<MatcherRewriteMode, AsRustError> {
60+
match self.0 {
61+
0 => Ok(MatcherRewriteMode::MatcherRewriteAuto),
62+
1 => Ok(MatcherRewriteMode::MatcherRewriteAlways),
63+
2 => Ok(MatcherRewriteMode::MatcherRewriteNever),
64+
_ => Err(AsRustError::Other(Box::new(EnumConversionError {}))),
65+
}
66+
}
67+
}
68+
69+
impl CDrop for CMatcherRewriteMode {
70+
fn do_drop(&mut self) -> Result<(), CDropError> {
71+
Ok(())
72+
}
73+
}
74+
75+
impl CReprOf<MatcherRewriteMode> for CMatcherRewriteMode {
76+
fn c_repr_of(value: MatcherRewriteMode) -> Result<CMatcherRewriteMode, CReprOfError> {
77+
let variant = match value {
78+
MatcherRewriteMode::MatcherRewriteAuto => 0,
79+
MatcherRewriteMode::MatcherRewriteAlways => 1,
80+
MatcherRewriteMode::MatcherRewriteNever => 2,
81+
};
82+
Ok(CMatcherRewriteMode(variant))
83+
}
84+
}
85+
86+
#[derive(AsRust, CReprOf, CDrop, RawPointerConverter, Debug, Clone)]
87+
#[target_type(SigmaMatcherConfig)]
88+
pub struct CSigmaMatcherConfig {
89+
pub sigma_label: CLabel,
90+
pub rewrite_mode: CMatcherRewriteMode,
91+
}
92+
93+
#[derive(RawPointerConverter, Debug, Clone, Default)]
94+
pub struct CMatcherConfig {
95+
pub sigma_matcher_config: Option<CSigmaMatcherConfig>,
96+
}
97+
98+
impl AsRust<MatcherConfig> for CMatcherConfig {
99+
fn as_rust(&self) -> Result<MatcherConfig, AsRustError> {
100+
if let Some(v) = &self.sigma_matcher_config {
101+
Ok(MatcherConfig {
102+
sigma_matcher_config: Some(v.as_rust()?),
103+
})
104+
} else {
105+
Ok(MatcherConfig {
106+
sigma_matcher_config: None,
107+
})
108+
}
109+
}
110+
}
111+
112+
impl CDrop for CMatcherConfig {
113+
fn do_drop(&mut self) -> Result<(), CDropError> {
114+
self.sigma_matcher_config
115+
.as_mut()
116+
.map(|v| v.do_drop())
117+
.transpose()?;
118+
Ok(())
119+
}
120+
}
121+
122+
impl CReprOf<MatcherConfig> for CMatcherConfig {
123+
fn c_repr_of(input: MatcherConfig) -> Result<Self, CReprOfError> {
124+
if let Some(v) = input.sigma_matcher_config {
125+
Ok(Self {
126+
sigma_matcher_config: Some(CReprOf::c_repr_of(v)?),
127+
})
128+
} else {
129+
Ok(Self {
130+
sigma_matcher_config: None,
131+
})
132+
}
133+
}
134+
}
135+
53136
#[derive(AsRust, CReprOf, CDrop, RawPointerConverter, Debug)]
54137
#[target_type(ComposeConfig)]
55138
pub struct CComposeConfig {
56139
pub compose_filter: CComposeFilterEnum,
57140
pub connect: bool,
141+
pub matcher1_config: CMatcherConfig,
142+
pub matcher2_config: CMatcherConfig,
143+
}
144+
145+
#[no_mangle]
146+
pub extern "C" fn fst_matcher_config_new(
147+
sigma_label: libc::size_t,
148+
rewrite_mode: libc::size_t,
149+
config: *mut *const CMatcherConfig,
150+
) -> RUSTFST_FFI_RESULT {
151+
wrap(|| {
152+
let matcher_config = CMatcherConfig {
153+
sigma_matcher_config: Some(CSigmaMatcherConfig {
154+
sigma_label: sigma_label as CLabel,
155+
rewrite_mode: CMatcherRewriteMode(rewrite_mode as usize),
156+
}),
157+
};
158+
159+
unsafe { *config = matcher_config.into_raw_pointer() };
160+
Ok(())
161+
})
58162
}
59163

60164
#[no_mangle]
61165
pub extern "C" fn fst_compose_config_new(
62166
compose_filter: libc::size_t,
63167
connect: bool,
168+
matcher1_config: *const CMatcherConfig,
169+
matcher2_config: *const CMatcherConfig,
64170
config: *mut *const CComposeConfig,
65171
) -> RUSTFST_FFI_RESULT {
66172
wrap(|| {
173+
let matcher1_config = if matcher1_config.is_null() {
174+
CMatcherConfig::default()
175+
} else {
176+
unsafe {
177+
<CMatcherConfig as ffi_convert::RawBorrow<CMatcherConfig>>::raw_borrow(
178+
matcher1_config,
179+
)?
180+
}
181+
.clone()
182+
};
183+
184+
let matcher2_config = if matcher2_config.is_null() {
185+
CMatcherConfig::default()
186+
} else {
187+
unsafe {
188+
<CMatcherConfig as ffi_convert::RawBorrow<CMatcherConfig>>::raw_borrow(
189+
matcher2_config,
190+
)?
191+
}
192+
.clone()
193+
};
194+
67195
let compose_config = CComposeConfig {
196+
matcher1_config,
197+
matcher2_config,
68198
compose_filter: CComposeFilterEnum(compose_filter as usize),
69199
connect,
70200
};

rustfst-python/rustfst/algorithms/compose.py

Lines changed: 48 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22
import ctypes
3+
from typing import Optional
4+
35
from rustfst.ffi_utils import (
46
lib,
57
check_ffi_error,
@@ -11,6 +13,25 @@
1113
from enum import Enum
1214

1315

16+
class MatcherRewriteMode(Enum):
17+
AUTO = 0
18+
ALWAYS = 1
19+
NEVER = 2
20+
21+
22+
class MatcherConfig:
23+
def __init__(self, sigma_label: int, rewrite_mode: MatcherRewriteMode):
24+
config = ctypes.pointer(ctypes.c_void_p())
25+
ret_code = lib.fst_matcher_config_new(
26+
ctypes.c_size_t(sigma_label),
27+
ctypes.c_size_t(rewrite_mode.value),
28+
ctypes.byref(config),
29+
)
30+
err_msg = "Error creating MatcherConfig"
31+
check_ffi_error(ret_code, err_msg)
32+
self.ptr = config
33+
34+
1435
class ComposeFilter(Enum):
1536
AUTOFILTER = 0
1637
NULLFILTER = 1
@@ -22,21 +43,33 @@ class ComposeFilter(Enum):
2243

2344

2445
class ComposeConfig:
25-
def __init__(self, compose_filter=None, connect: bool = None):
26-
if compose_filter and connect is None:
27-
self.ptr = compose_filter
28-
elif compose_filter and connect:
29-
config = ctypes.pointer(ctypes.c_void_p())
30-
ret_code = lib.fst_compose_config_new(
31-
ctypes.c_size_t(compose_filter.value),
32-
ctypes.c_bool(connect),
33-
ctypes.byref(config),
34-
)
35-
err_msg = "Error creating ComposeConfig"
36-
check_ffi_error(ret_code, err_msg)
37-
self.ptr = config
38-
else:
39-
raise ValueError("Could not create ComposeConfig")
46+
def __init__(
47+
self,
48+
compose_filter: ComposeFilter = ComposeFilter.AUTOFILTER,
49+
connect: bool = True,
50+
matcher1_config: Optional[MatcherConfig] = None,
51+
matcher2_config: Optional[MatcherConfig] = None,
52+
):
53+
config = ctypes.pointer(ctypes.c_void_p())
54+
55+
m1_ptr = None
56+
if matcher1_config is not None:
57+
m1_ptr = matcher1_config.ptr
58+
m2_ptr = None
59+
60+
if matcher2_config is not None:
61+
m2_ptr = matcher2_config.ptr
62+
63+
ret_code = lib.fst_compose_config_new(
64+
ctypes.c_size_t(compose_filter.value),
65+
ctypes.c_bool(connect),
66+
m1_ptr,
67+
m2_ptr,
68+
ctypes.byref(config),
69+
)
70+
err_msg = "Error creating ComposeConfig"
71+
check_ffi_error(ret_code, err_msg)
72+
self.ptr = config
4073

4174

4275
def compose(fst: VectorFst, other_fst: VectorFst) -> VectorFst:

rustfst-python/tests/algorithms/test_compose.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,13 @@
11
from rustfst import VectorFst, Tr
2-
from rustfst.algorithms.compose import ComposeFilter, ComposeConfig
2+
from rustfst.algorithms.compose import (
3+
ComposeFilter,
4+
ComposeConfig,
5+
MatcherConfig,
6+
MatcherRewriteMode,
7+
compose_with_config,
8+
)
9+
from rustfst.symbol_table import SymbolTable
10+
from rustfst.algorithms import acceptor
311

412

513
def test_compose_fst():
@@ -144,3 +152,26 @@ def test_compose_config():
144152

145153
fst3 = fst1.compose(fst2, compose_config)
146154
assert fst3 == expected_fst
155+
156+
157+
def test_sigma_compose():
158+
symt = SymbolTable.from_symbols(
159+
["<eps>", "play", "david", "queen", "please", "<sigma>"]
160+
)
161+
162+
query_fst = acceptor("play queen please", symt)
163+
sigma_fst = acceptor("play <sigma> please", symt)
164+
165+
matcher_config_right = MatcherConfig(
166+
sigma_label=symt.find("<sigma>"), rewrite_mode=MatcherRewriteMode.AUTO
167+
)
168+
169+
compose_config = ComposeConfig(
170+
compose_filter=ComposeFilter.SEQUENCEFILTER,
171+
connect=True,
172+
matcher2_config=matcher_config_right,
173+
)
174+
175+
res = compose_with_config(query_fst, sigma_fst, compose_config)
176+
177+
assert res == query_fst

0 commit comments

Comments
 (0)