Skip to content

Commit e813d2a

Browse files
committed
Add reading of setuptools metadata to find smart_open transport / compressor extensions
1 parent fe6cf99 commit e813d2a

File tree

5 files changed

+145
-1
lines changed

5 files changed

+145
-1
lines changed

smart_open/compression.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
#
88
"""Implements the compression layer of the ``smart_open`` library."""
99
import logging
10+
import importlib
11+
import importlib.metadata
1012
import os.path
1113

1214
logger = logging.getLogger(__name__)
@@ -145,3 +147,16 @@ def compression_wrapper(file_obj, mode, compression):
145147
#
146148
register_compressor('.bz2', _handle_bz2)
147149
register_compressor('.gz', _handle_gzip)
150+
151+
152+
def _register_compressor_entry_point(ep):
153+
try:
154+
assert len(ep.name) > 0, "At least one char is required for ep.name"
155+
extension = ".{}".format(ep.name)
156+
register_compressor(extension, ep.load())
157+
except Exception:
158+
logger.warning("Fail to load smart_open compressor extension: %s (target: %s)", ep.name, ep.value)
159+
160+
161+
for ep in importlib.metadata.entry_points().select(group='smart_open_compressor'):
162+
_register_compressor_entry_point(ep)
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# -*- coding: utf-8 -*-
2+
"""Some no-op compressor"""
3+
4+
5+
def handle_foo():
6+
...
7+
8+
9+
def handle_bar():
10+
...
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# -*- coding: utf-8 -*-
2+
3+
from importlib.metadata import EntryPoint
4+
import pytest
5+
6+
from smart_open.compression import _COMPRESSOR_REGISTRY, _register_compressor_entry_point
7+
8+
9+
def unregister_compressor(ext):
10+
if ext in _COMPRESSOR_REGISTRY:
11+
del _COMPRESSOR_REGISTRY[ext]
12+
13+
14+
@pytest.fixture(autouse=True)
15+
def cleanup_compressor():
16+
unregister_compressor(".foo")
17+
unregister_compressor(".bar")
18+
19+
20+
def test_register_valid_entry_point():
21+
assert ".foo" not in _COMPRESSOR_REGISTRY
22+
assert ".bar" not in _COMPRESSOR_REGISTRY
23+
_register_compressor_entry_point(EntryPoint(
24+
"foo",
25+
"smart_open.tests.fixtures.compressor:handle_bar",
26+
"smart_open_compressor",
27+
))
28+
_register_compressor_entry_point(EntryPoint(
29+
"bar",
30+
"smart_open.tests.fixtures.compressor:handle_bar",
31+
"smart_open_compressor",
32+
))
33+
assert ".foo" in _COMPRESSOR_REGISTRY
34+
assert ".bar" in _COMPRESSOR_REGISTRY
35+
36+
37+
def test_register_invalid_entry_point_name_do_not_crash():
38+
_register_compressor_entry_point(EntryPoint(
39+
"",
40+
"smart_open.tests.fixtures.compressor:handle_foo",
41+
"smart_open_compressor",
42+
))
43+
assert "" not in _COMPRESSOR_REGISTRY
44+
assert "." not in _COMPRESSOR_REGISTRY
45+
46+
47+
def test_register_invalid_entry_point_value_do_not_crash():
48+
_register_compressor_entry_point(EntryPoint(
49+
"foo",
50+
"smart_open.tests.fixtures.compressor:handle_invalid",
51+
"smart_open_compressor",
52+
))
53+
assert ".foo" not in _COMPRESSOR_REGISTRY

smart_open/tests/test_transport.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,44 @@
11
# -*- coding: utf-8 -*-
2+
from importlib.metadata import EntryPoint
23
import pytest
34
import unittest
45

5-
from smart_open.transport import register_transport, get_transport
6+
from smart_open.transport import (
7+
register_transport, get_transport, _REGISTRY, _ERRORS, _register_transport_entry_point
8+
)
9+
10+
11+
def unregister_transport(x):
12+
if x in _REGISTRY:
13+
del _REGISTRY[x]
14+
if x in _ERRORS:
15+
del _ERRORS[x]
16+
17+
18+
def assert_transport_not_registered(scheme):
19+
with pytest.raises(NotImplementedError):
20+
get_transport(scheme)
21+
22+
23+
def assert_transport_registered(scheme):
24+
transport = get_transport(scheme)
25+
assert transport.SCHEME == scheme
626

727

828
class TransportTest(unittest.TestCase):
29+
def tearDown(self):
30+
unregister_transport("foo")
31+
unregister_transport("missing")
932

1033
def test_registry_requires_declared_schemes(self):
1134
with pytest.raises(ValueError):
1235
register_transport('smart_open.tests.fixtures.no_schemes_transport')
1336

37+
def test_registry_valid_transport(self):
38+
assert_transport_not_registered("foo")
39+
register_transport('smart_open.tests.fixtures.good_transport')
40+
assert_transport_registered("foo")
41+
1442
def test_registry_errors_on_double_register_scheme(self):
1543
register_transport('smart_open.tests.fixtures.good_transport')
1644
with pytest.raises(AssertionError):
@@ -20,3 +48,29 @@ def test_registry_errors_get_transport_for_module_with_missing_deps(self):
2048
register_transport('smart_open.tests.fixtures.missing_deps_transport')
2149
with pytest.raises(ImportError):
2250
get_transport("missing")
51+
52+
def test_register_entry_point_valid(self):
53+
assert_transport_not_registered("foo")
54+
_register_transport_entry_point(EntryPoint(
55+
"foo",
56+
"smart_open.tests.fixtures.good_transport",
57+
"smart_open_transport",
58+
))
59+
assert_transport_registered("foo")
60+
61+
def test_register_entry_point_catch_bad_data(self):
62+
_register_transport_entry_point(EntryPoint(
63+
"invalid",
64+
"smart_open.some_totaly_invalid_module",
65+
"smart_open_transport",
66+
))
67+
68+
def test_register_entry_point_for_module_with_missing_deps(self):
69+
assert_transport_not_registered("missing")
70+
_register_transport_entry_point(EntryPoint(
71+
"missing",
72+
"smart_open.tests.fixtures.missing_deps_transport",
73+
"smart_open_transport",
74+
))
75+
with pytest.raises(ImportError):
76+
get_transport("missing")

smart_open/transport.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
1212
"""
1313
import importlib
14+
import importlib.metadata
1415
import logging
1516

1617
import smart_open.local_file
@@ -102,5 +103,16 @@ def get_transport(scheme):
102103
register_transport('smart_open.ssh')
103104
register_transport('smart_open.webhdfs')
104105

106+
107+
def _register_transport_entry_point(ep):
108+
try:
109+
register_transport(ep.value)
110+
except Exception:
111+
logger.warning("Fail to load smart_open transport extension: %s (target: %s)", ep.name, ep.value)
112+
113+
114+
for ep in importlib.metadata.entry_points().select(group='smart_open_transport'):
115+
_register_transport_entry_point(ep)
116+
105117
SUPPORTED_SCHEMES = tuple(sorted(_REGISTRY.keys()))
106118
"""The transport schemes that the local installation of ``smart_open`` supports."""

0 commit comments

Comments
 (0)