Skip to content

Commit bc4ec92

Browse files
authored
Add full text feature builder (#125)
1 parent 4d5485f commit bc4ec92

File tree

3 files changed

+52
-8
lines changed

3 files changed

+52
-8
lines changed

src/trustyai/model/__init__.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -582,17 +582,17 @@ class _JFeature:
582582
"""Java Feature implicit methods"""
583583

584584
@property
585-
def name(self):
585+
def name(self) -> str:
586586
"""Return name"""
587587
return self.getName()
588588

589589
@property
590-
def type(self):
590+
def type(self) -> Type:
591591
"""Return type"""
592592
return self.getType()
593593

594594
@property
595-
def value(self):
595+
def value(self) -> Value:
596596
"""Return value"""
597597
return self.getValue()
598598

@@ -605,7 +605,7 @@ def domain(self):
605605
return _domain
606606

607607
@property
608-
def is_constrained(self):
608+
def is_constrained(self) -> bool:
609609
"""Return contraint"""
610610
return self.isConstrained()
611611

@@ -755,7 +755,19 @@ def output(name, dtype, value=None, score=1.0) -> _Output:
755755
return _Output(name, _type, Value(value), score)
756756

757757

758-
def feature(name: str, dtype: str, value=None, domain=None) -> Feature:
758+
def full_text_feature(
759+
name: str, value: str, tokenizer: Callable[[str], List[str]] = None
760+
) -> Feature:
761+
"""Create a full-text composite feature using TrustyAI methods"""
762+
return FeatureFactory.newFulltextFeature(name, value, tokenizer)
763+
764+
765+
def feature(
766+
name: str,
767+
dtype: str,
768+
value=None,
769+
domain=None,
770+
) -> Feature:
759771
"""Create a Java :class:`Feature`. The :class:`Feature` class is used to represent the
760772
individual components (or features) of input data points.
761773

src/trustyai/utils/text.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
"""Utility methods for text data handling"""
2+
from typing import List, Callable
3+
4+
from jpype import _jclass
5+
6+
7+
def tokenizer(function: Callable[[str], List[str]]):
8+
"""Post-process outputs of a Python tokenizer function"""
9+
10+
def wrapper(_input: str):
11+
return _jclass.JClass("java.util.Arrays").asList(function(_input))
12+
13+
return wrapper

tests/general/test_conversions.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
# pylint: disable=import-error, wrong-import-position, wrong-import-order, invalid-name
22
"""Implicit conversion test suite"""
3+
from typing import List
4+
35
from common import *
46

57
from jpype import _jclass
68

7-
from trustyai.model import feature
9+
from trustyai.model import feature, full_text_feature
810
from trustyai.model.domain import feature_domain
911
from trustyai.utils.data_conversions import (
1012
one_input_convert,
@@ -14,6 +16,7 @@
1416
)
1517
from org.kie.trustyai.explainability.model import Type
1618

19+
from trustyai.utils import text
1720

1821

1922
def test_list_python_to_java():
@@ -92,6 +95,7 @@ def test_categorical_domain_tuple():
9295
assert jdomain.getCategories().size() == 3
9396
assert jdomain.getCategories().containsAll(domain)
9497

98+
9599
def test_feature_function():
96100
"""Test helper method to create features"""
97101
f1 = feature(name="f-1", value=1.0, dtype="number")
@@ -114,6 +118,21 @@ def test_feature_function():
114118
assert f4.value.as_number() == 5
115119
assert f4.type == Type.CATEGORICAL
116120

121+
@text.tokenizer
122+
def tokenizer(x: str) -> List[str]:
123+
return x.split(" ")
124+
125+
values = "you just requested to change your password"
126+
f5 = full_text_feature(name="f-5", value=values, tokenizer=tokenizer)
127+
assert f5.name == "f-5"
128+
assert len(f5.value.as_obj()) == 7
129+
sub_features = f5.value.as_obj()
130+
tokens = values.split(" ")
131+
for i in range(7):
132+
assert sub_features[i].name == "f-5_" + str(i + 1)
133+
assert sub_features[i].value.as_string() == tokens[i]
134+
assert f5.type == Type.COMPOSITE
135+
117136

118137
def test_feature_domains():
119138
"""Test domains"""
@@ -248,7 +267,7 @@ def test_many_inputs_conversion_domained():
248267

249268
domain_bounds = [[np.random.rand(), np.random.rand()] for _ in range(n_feats)]
250269
domains = [feature_domain((lb, ub)) for lb, ub in domain_bounds]
251-
numpy1 = np.arange(0, n_feats*n_datapoints).reshape(-1, n_feats)
270+
numpy1 = np.arange(0, n_feats * n_datapoints).reshape(-1, n_feats)
252271
df = pd.DataFrame(numpy1, columns=["input-{}".format(i) for i in range(n_feats)])
253272

254273
ta_numpy1 = many_inputs_convert(numpy1, feature_domains=domains)
@@ -266,4 +285,4 @@ def test_many_inputs_conversion_domained():
266285
== domain_bounds[j][1]
267286

268287
for i in range(n_datapoints):
269-
assert ta_numpy1[i].equals(ta_df[i])
288+
assert ta_numpy1[i].equals(ta_df[i])

0 commit comments

Comments
 (0)