diff --git a/sphinx/builders/__init__.py b/sphinx/builders/__init__.py index 7265ade3604..6796cbd32f3 100644 --- a/sphinx/builders/__init__.py +++ b/sphinx/builders/__init__.py @@ -37,7 +37,6 @@ from sphinx.config import Config from sphinx.events import EventManager from sphinx.util.tags import Tags - from sphinx.util.typing import NoneType logger = logging.getLogger(__name__) @@ -581,10 +580,6 @@ def _write_serial(self, docnames: Sequence[str]) -> None: self.write_doc(docname, doctree) def _write_parallel(self, docnames: Sequence[str], nproc: int) -> None: - def write_process(docs: list[tuple[str, nodes.document]]) -> None: - self.app.phase = BuildPhase.WRITING - for docname, doctree in docs: - self.write_doc(docname, doctree) # warm up caches/compile templates using the first document firstname, docnames = docnames[0], docnames[1:] @@ -602,17 +597,69 @@ def write_process(docs: list[tuple[str, nodes.document]]) -> None: progress = status_iterator(chunks, __('writing output... '), "darkgreen", len(chunks), self.app.verbosity) - def on_chunk_done(args: list[tuple[str, NoneType]], result: NoneType) -> None: - next(progress) - self.app.phase = BuildPhase.RESOLVING - for chunk in chunks: - arg = [] - for docname in chunk: - doctree = self.env.get_and_resolve_doctree(docname, self) - self.write_doc_serialized(docname, doctree) - arg.append((docname, doctree)) - tasks.add_task(write_process, arg, on_chunk_done) + + if not self.config.parallel_post_transform: + + # This is the "original" parallel write logic: + # only the final writing of the output is parallelised, + # not the application of post-transforms, etc + # The `write_doc` method should not add/modify any data + # required by the parent process + + def _write_doc(docs: list[tuple[str, nodes.document]]) -> None: + self.app.phase = BuildPhase.WRITING + for docname, doctree in docs: + self.write_doc(docname, doctree) + + def _on_chunk_done(args: list[tuple[str, None]], result: None) -> None: + next(progress) + + 
for chunk in chunks: + arg = [] + for docname in chunk: + doctree = self.env.get_and_resolve_doctree(docname, self) + self.write_doc_serialized(docname, doctree) + arg.append((docname, doctree)) + tasks.add_task(_write_doc, arg, _on_chunk_done) + + else: + + # This is the "new" parallel write logic; + # The entire logic is performed in parallel. + # However, certain data during this phase must be parsed back from child processes, + # to be used by the main process in the final build steps. + # This is achieved by allowing the builder and any subscribers to the events below, + # to (1) add data to a context, within the child process, + # (2) moving that context back to the parent process, via pickling, and + # (3) merge the data from context into the required location on the parent process + + logger.warning( + "parallel_post_transform is experimental " + "(add 'config.experimental' to suppress_warnings)", + type="config", + subtype="experimental" + ) + + def _write(docnames: list[str]) -> bytes: + for docname in docnames: + doctree = self.env.get_and_resolve_doctree(docname, self) + self.write_doc_serialized(docname, doctree) + self.app.phase = BuildPhase.WRITING + self.write_doc(docname, doctree) + context: dict[str, Any] = {} + self.parallel_write_data_retrieve(context, docnames) + self.events.emit('write-data-retrieve', self, context, docnames) + return pickle.dumps(context, pickle.HIGHEST_PROTOCOL) + + def _merge(docnames: list[str], context_bytes: bytes) -> None: + context: dict[str, Any] = pickle.loads(context_bytes) + self.parallel_write_data_merge(context, docnames) + self.events.emit('write-data-merge', self, context, docnames) + next(progress) + + for docnames in chunks: + tasks.add_task(_write, docnames, _merge) # make sure all threads have finished tasks.join() @@ -636,6 +683,24 @@ def write_doc_serialized(self, docname: str, doctree: nodes.document) -> None: """ pass + def parallel_write_data_retrieve( + self, context: dict[str, Any], docnames: 
list[str] + ) -> None: + """Retrieve data from child process of parallel write, + to be passed back to main process. + + :param context: Add data here to be passed back. + All data must be picklable. + :param docnames: List of docnames that were written in the child process. + """ + + def parallel_write_data_merge(self, context: dict[str, Any], docnames: list[str]) -> None: + """Merge data from child process of parallel write into main process. + + :param context: Data from the child process. + :param docnames: List of docnames that were written in the child process. + """ + + def finish(self) -> None: """Finish the building process. diff --git a/sphinx/builders/html/__init__.py b/sphinx/builders/html/__init__.py index 75b0a394ba9..ad439bf9726 100644 --- a/sphinx/builders/html/__init__.py +++ b/sphinx/builders/html/__init__.py @@ -1147,6 +1147,20 @@ def update_page_context(self, pagename: str, templatename: str, ctx: dict, event_arg: Any) -> None: pass + def parallel_write_data_retrieve( + self, context: dict[str, Any], docnames: list[str] + ) -> None: + context['indexer'] = self.indexer + context['images'] = self.images + + def parallel_write_data_merge(self, context: dict[str, Any], docnames: list[str]) -> None: + if (indexer := context.get("indexer")) and self.indexer is not None: + # TODO can self.indexer be None if indexer is not None? 
+ self.indexer.merge_other(indexer) + for filepath, filename in context['images'].items(): + if filepath not in self.images: + self.images[filepath] = filename + def handle_finish(self) -> None: self.finish_tasks.add_task(self.dump_search_index) self.finish_tasks.add_task(self.dump_inventory) diff --git a/sphinx/config.py b/sphinx/config.py index 1f4b47067ab..4bcd420fcd5 100644 --- a/sphinx/config.py +++ b/sphinx/config.py @@ -266,6 +266,7 @@ class Config: 'smartquotes_excludes': _Opt( {'languages': ['ja'], 'builders': ['man', 'text']}, 'env', ()), 'option_emphasise_placeholders': _Opt(False, 'env', ()), + 'parallel_post_transform': _Opt(False, 'env', ()), } def __init__(self, config: dict[str, Any] | None = None, diff --git a/sphinx/events.py b/sphinx/events.py index af8dfb4e2cf..38ef48b8d27 100644 --- a/sphinx/events.py +++ b/sphinx/events.py @@ -45,6 +45,8 @@ class EventListener(NamedTuple): 'warn-missing-reference': 'domain, node', 'doctree-resolved': 'doctree, docname', 'env-updated': 'env', + 'write-data-retrieve': 'builder, context, written docnames', + 'write-data-merge': 'builder, context, written docnames', 'build-finished': 'exception', } diff --git a/sphinx/search/__init__.py b/sphinx/search/__init__.py index 2638f92ffb4..1a26e437a0d 100644 --- a/sphinx/search/__init__.py +++ b/sphinx/search/__init__.py @@ -331,6 +331,20 @@ def dump(self, stream: IO, format: Any) -> None: format = self.formats[format] format.dump(self.freeze(), stream) + def merge_other(self, other: IndexBuilder) -> None: + """Merge another frozen index into this one.""" + # TODO test this + self._all_titles |= other._all_titles + self._filenames |= other._filenames + self._index_entries |= other._index_entries + self._mapping |= other._mapping + self._title_mapping |= other._title_mapping + self._titles |= other._titles + + def __getstate__(self): + # TODO improve this + return {k: v for k, v in self.__dict__.items() if k != 'env'} + def get_objects(self, fn2index: dict[str, int] ) -> 
dict[str, list[tuple[int, int, int, str, str]]]: rv: dict[str, list[tuple[int, int, int, str, str]]] = {}