From 508fdf622a980b568d4c217fc04676d784518980 Mon Sep 17 00:00:00 2001
From: Alex Light
Date: Wed, 24 Sep 2025 13:52:23 -0700
Subject: [PATCH] `NodeMap`: an interned Map implementation.

This is a std::map-alike specifically for maps with Node* as keys, which holds the mapped value logically interned into the node itself. This is substantially faster than a flat-hash-map in many cases, with speedups of up to 10% seen in optimization time on some targets when integrated into a few passes.

This implementation is specifically made so that it can (mostly) function as a drop-in replacement for flat_hash_map. There are some notable differences, however:

1) All operations perform a pointer read on the key argument. This means calling practically any function with an invalid Node* is UB.
2) reserve(...) is not available (since it does nothing).
3) Iteration order is defined as most-to-least-recently added.
4) Each node with any data associated with it on any map has a vector with up to {max number of live maps} pointers associated with it. This extra space is never cleaned up on the assumption that only a relatively small number of maps will ever be simultaneously live.
5) Once a node has been deleted from a function/package/etc., all data associated with it on any map is deallocated. The node will not be visible in any way in the map (i.e. size will go down by 1) and it will be UB to try to read anything about that key.
6) All keys used in a NodeMap must be from the same package.
7) NodeMap has pointer, reference and iterator stability (adding/removing new values to the map does not change the pointer to the value or invalidate any iterators or references except those pointing to the exact entry removed).

Any read of the map requires 4 pointer reads (node -> user-data-vector -> map node -> value). This map is (very slightly) less efficient than flat_hash_map if the value type is trivially copyable.
To ensure that this is respected by default the map will have a compile error if you attempt to use it with these sorts of values. PiperOrigin-RevId: 811020776 --- GEMINI.md | 7 + xls/common/BUILD | 5 + xls/common/pointer_utils.h | 38 ++ xls/data_structures/inline_bitmap.h | 7 +- xls/ir/BUILD | 40 ++ xls/ir/node.cc | 25 ++ xls/ir/node.h | 33 ++ xls/ir/node_map.h | 634 ++++++++++++++++++++++++++++ xls/ir/node_map_test.cc | 408 ++++++++++++++++++ xls/ir/package.cc | 48 ++- xls/ir/package.h | 21 + 11 files changed, 1262 insertions(+), 4 deletions(-) create mode 100644 xls/common/pointer_utils.h create mode 100644 xls/ir/node_map.h create mode 100644 xls/ir/node_map_test.cc diff --git a/GEMINI.md b/GEMINI.md index 57931f957f..eae9943099 100644 --- a/GEMINI.md +++ b/GEMINI.md @@ -26,3 +26,10 @@ This file provides a context for the XLS project. **Documentation:** * OSS docs are in `docs_src` and rendered with `mkdocs` at [https://google.github.io/xls/](https://google.github.io/xls/). + +**‼️ Agent Instructions ‼️** + +* **NodeMap/NodeSet Usage:** New code should generally use `NodeMap` or + `NodeSet` instead of `absl::flat_hash_map` or `absl::flat_hash_set` when the + key is `Node*` and the value is not trivially copyable. Existing code should + not be modified unless specifically directed to do so. diff --git a/xls/common/BUILD b/xls/common/BUILD index 835d488a55..e26ae38f25 100644 --- a/xls/common/BUILD +++ b/xls/common/BUILD @@ -640,6 +640,11 @@ cc_library( ], ) +cc_library( + name = "pointer_utils", + hdrs = ["pointer_utils.h"], +) + cc_test( name = "stopwatch_test", srcs = ["stopwatch_test.cc"], diff --git a/xls/common/pointer_utils.h b/xls/common/pointer_utils.h new file mode 100644 index 0000000000..b046cbf5e7 --- /dev/null +++ b/xls/common/pointer_utils.h @@ -0,0 +1,38 @@ +// Copyright 2025 The XLS Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef XLS_COMMON_POINTER_UTILS_H_ +#define XLS_COMMON_POINTER_UTILS_H_ + +#include + +namespace xls { + +namespace internal { +using RawDeletePtr = void (*)(void*); +} + +// type-erased version of unique ptr that keeps track of the appropriate +// destructor. To use reinterpret_cast(x.get()). +using TypeErasedUniquePtr = std::unique_ptr; +template > + +TypeErasedUniquePtr EraseType(std::unique_ptr ptr) { + return TypeErasedUniquePtr( + ptr.release(), [](void* ptr) { delete reinterpret_cast(ptr); }); +} + +} // namespace xls + +#endif // XLS_COMMON_POINTER_UTILS_H_ diff --git a/xls/data_structures/inline_bitmap.h b/xls/data_structures/inline_bitmap.h index c60f4c6ef8..c2c13efd32 100644 --- a/xls/data_structures/inline_bitmap.h +++ b/xls/data_structures/inline_bitmap.h @@ -39,6 +39,10 @@ class BitmapView; // A bitmap that has 64-bits of inline storage by default. class InlineBitmap { public: + // How many bits are held in one word. + static constexpr int64_t kWordBits = 64; + // How many bytes are held in one word. + static constexpr int64_t kWordBytes = 8; // Constructs an InlineBitmap of width `bit_count` using the bits in // `word`. If `bit_count` is greater than 64, then all high bits are set to // `fill`. 
@@ -345,9 +349,6 @@ class InlineBitmap { friend uint64_t GetWordBitsAtForTest(const InlineBitmap& ib, int64_t bit_offset); - static constexpr int64_t kWordBits = 64; - static constexpr int64_t kWordBytes = 8; - // Gets the kWordBits bits following bit_offset with 'Get(bit_offset)' being // the LSB, Get(bit_offset + 1) being the next lsb etc. int64_t GetWordBitsAt(int64_t bit_offset) const; diff --git a/xls/ir/BUILD b/xls/ir/BUILD index a4fabda9ed..345fa5ea89 100644 --- a/xls/ir/BUILD +++ b/xls/ir/BUILD @@ -638,9 +638,11 @@ cc_library( "//xls/common:casts", "//xls/common:iterator_range", "//xls/common:math_util", + "//xls/common:pointer_utils", "//xls/common:visitor", "//xls/common/status:ret_check", "//xls/common/status:status_macros", + "//xls/data_structures:inline_bitmap", "//xls/data_structures:leaf_type_tree", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", @@ -2491,3 +2493,41 @@ cc_test( "@googletest//:gtest", ], ) + +cc_library( + name = "node_map", + hdrs = ["node_map.h"], + deps = [ + ":ir", + "//xls/common:pointer_utils", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:hash_container_defaults", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@cppitertools", + ], +) + +cc_test( + name = "node_map_test", + srcs = ["node_map_test.cc"], + deps = [ + ":benchmark_support", + ":bits", + ":function_builder", + ":ir", + ":ir_matcher", + ":ir_test_base", + ":node_map", + "//xls/common:xls_gunit_main", + "//xls/common/status:matchers", + "//xls/common/status:status_macros", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@google_benchmark//:benchmark", + "@googletest//:gtest", + ], +) diff --git a/xls/ir/node.cc b/xls/ir/node.cc index da8e82add1..4ada070b3b 100644 --- a/xls/ir/node.cc +++ 
b/xls/ir/node.cc @@ -36,6 +36,7 @@ #include "absl/strings/str_join.h" #include "absl/types/span.h" #include "xls/common/casts.h" +#include "xls/common/pointer_utils.h" #include "xls/common/status/ret_check.h" #include "xls/common/status/status_macros.h" #include "xls/ir/change_listener.h" @@ -997,4 +998,28 @@ bool Node::OpIn(absl::Span choices) const { Package* Node::package() const { return function_base()->package(); } +std::optional Node::TakeUserData(int64_t idx) { + DCHECK(function_base()->package()->IsLiveUserDataId(idx)) << idx; + if (user_data_.size() <= idx) { + return std::nullopt; + } + std::optional data = std::move(user_data_[idx]); + user_data_[idx] = std::nullopt; + return data; +} +void* Node::GetUserData(int64_t idx) { + DCHECK(function_base()->package()->IsLiveUserDataId(idx)) << idx; + if (user_data_.size() <= idx) { + user_data_.resize(idx + 1); + } + const auto& v = user_data_[idx]; + return v ? v->get() : nullptr; +} +void Node::SetUserData(int64_t idx, TypeErasedUniquePtr data) { + DCHECK(function_base()->package()->IsLiveUserDataId(idx)) << idx; + if (user_data_.size() <= idx) { + user_data_.resize(idx + 1); + } + user_data_[idx] = std::move(data); +} } // namespace xls diff --git a/xls/ir/node.h b/xls/ir/node.h index 4b79928eff..b4c1273b84 100644 --- a/xls/ir/node.h +++ b/xls/ir/node.h @@ -24,6 +24,7 @@ #include #include #include +#include #include "absl/container/inlined_vector.h" #include "absl/log/check.h" @@ -31,6 +32,7 @@ #include "absl/status/statusor.h" #include "absl/types/span.h" #include "xls/common/casts.h" +#include "xls/common/pointer_utils.h" #include "xls/common/status/status_macros.h" #include "xls/ir/change_listener.h" #include "xls/ir/op.h" @@ -316,6 +318,28 @@ class Node { absl::Format(&sink, "%s", node.GetName()); } + // User-data access functions. Should not be directly used. Use NodeMap + // instead. 
+ // + // Extreme care should be used when interacting with these functions and the + // Package ones since this is basically doing manual memory management. + + // Get the pointer associated with this indexes user data or nullptr if + // never set. Use HasUserData to see if anything has ever been set. + // + // idx must be a value returned by Package::AllocateNodeUserData which has not + // had ReleaseNodeUserDataId called on it. + void* GetUserData(int64_t idx); + // Sets user data at idx to 'data'. + void SetUserData(int64_t idx, TypeErasedUniquePtr data); + // Removes user data at idx from the node. Returns std::nullopt if nothing has + // been set. + std::optional TakeUserData(int64_t idx); + // Checks if anything has ever been set at the given user data. + bool HasUserData(int64_t idx) { + return user_data_.size() > idx && user_data_[idx].has_value(); + } + protected: // FunctionBase needs to be a friend to access RemoveUser for deleting nodes // from the graph. @@ -368,6 +392,15 @@ class Node { // Set of users sorted by node_id for stability. absl::InlinedVector users_; + + private: + std::vector> user_data_; + + // Clear all user data. + void ClearUserData() { user_data_.clear(); }; + + // for ClearUserData + friend class Package; }; inline NodeRef::NodeRef(Node* node) diff --git a/xls/ir/node_map.h b/xls/ir/node_map.h new file mode 100644 index 0000000000..801ba24699 --- /dev/null +++ b/xls/ir/node_map.h @@ -0,0 +1,634 @@ +// Copyright 2025 The XLS Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef XLS_IR_NODE_MAP_H_ +#define XLS_IR_NODE_MAP_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/hash_container_defaults.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "cppitertools/imap.hpp" +#include "xls/common/pointer_utils.h" +#include "xls/ir/block.h" // IWYU pragma: keep +#include "xls/ir/function.h" // IWYU pragma: keep +#include "xls/ir/node.h" +#include "xls/ir/package.h" +#include "xls/ir/proc.h" // IWYU pragma: keep + +namespace xls { + +using ForceAllowNodeMap = std::false_type; +namespace internal { +template +struct ErrorOnSlower { + static_assert( + !kIsLikelySlower::value, + "NodeMap is likely slower than absl::flat_hash_map for this Value type " + "because it is trivially copyable. To override declare node map as " + "NodeMap<{ValueT}, ForceAllowNodeMap>. Care should be taken to validate " + "that this is actually a performance win however."); +}; +}; // namespace internal + +// `xls::NodeMap` is a map-like interface for holding mappings from `xls::Node*` +// to `ValueT`. It is designed to be a partial drop-in replacement for +// `absl::flat_hash_map` but with better performance in XLS +// workloads. +// +// `NodeMap` achieves this performance by storing `ValueT` logically within the +// `Node` object itself as 'user-data'. This avoids hashing `Node*` and reduces +// cache misses compared to `absl::flat_hash_map`. A read of a value requires +// only 4 pointer reads. +// +// Notable Differences from `absl::flat_hash_map`: +// +// * All operations inherently perform pointer reads on any Node* typed values +// in key positions. This means that attempting to use deallocated Node*s as +// keys **in any way** (including just calling contains, etc) is UB. 
+// * The node-map has pointer stability of its values as well as iterator +// stability (except for iterators pointing to an entry which is removed +// either by a call to erase or by removing the node which is the entries +// key). +// * `reserve()` is not available as `NodeMap` does not require upfront storage +// allocation in the same way as `absl::flat_hash_map`. +// * Iteration order is from most-recently inserted to least-recently inserted. +// * All keys in a `NodeMap` must come from the same `Package`. +// * If a `Node` is deleted from its function/package, any associated data in +// any `NodeMap` is deallocated and it is removed from the map. +// * Each node with data in any `NodeMap` has an internal vector to hold +// user-data for all live maps. This extra space is not cleaned up until +// package destruction, based on the assumption that only a small number of +// maps will be simultaneously live. +// +// WARNING: This map is not thread safe. Also destruction of a node which has a +// value mapped to it is a modification of the map and needs to be taken into +// account if using this map in a multi-threaded context. +// +// NB This does a lot of very unsafe stuff internally to store the data using +// node user-data. +template > +class NodeMap : public internal::ErrorOnSlower { + private: + // Intrusive list node to hold the actual data allowing us to iterate. + struct DataHolder { + template + DataHolder(Node* n, Args&&... args) + : value(std::piecewise_construct, std::forward_as_tuple(n), + std::forward_as_tuple(std::forward(args)...)), + iter() {} + + ~DataHolder() { + // Remove itself from the list on deletion. + if (configured_list) { + configured_list->erase(iter); + } + } + + std::pair value; + // Intrusive list node to allow for iteration that's somewhat fast. 
+ std::list::iterator iter; + std::list* configured_list = nullptr; + }; + + class ConstIterator; + class Iterator { + public: + using difference_type = ptrdiff_t; + using value_type = std::pair; + using reference = value_type&; + using pointer = value_type*; + using element_type = value_type; + using const_reference = const value_type&; + using iterator_category = std::forward_iterator_tag; + + Iterator() : iter_() {} + Iterator(std::list::iterator iter) : iter_(iter) {} + Iterator(const Iterator& other) : iter_(other.iter_) {} + Iterator& operator=(const Iterator& other) { + iter_ = other.iter_; + return *this; + } + Iterator& operator++() { + ++iter_; + return *this; + } + Iterator operator++(int) { + Iterator tmp = *this; + ++(*this); + return tmp; + } + reference operator*() const { return (*iter_)->value; } + pointer get() const { return &**this; } + pointer operator->() const { return &(*iter_)->value; } + friend bool operator==(const Iterator& a, const Iterator& b) { + return a.iter_ == b.iter_; + } + friend bool operator!=(const Iterator& a, const Iterator& b) { + return !(a == b); + } + + private: + std::list::iterator iter_; + friend class ConstIterator; + }; + class ConstIterator { + public: + using difference_type = ptrdiff_t; + using value_type = std::pair; + using reference = const value_type&; + using pointer = const value_type*; + using element_type = value_type; + using const_reference = const value_type&; + using iterator_category = std::forward_iterator_tag; + ConstIterator() : iter_() {} + ConstIterator(std::list::const_iterator iter) : iter_(iter) {} + ConstIterator(std::list::iterator iter) : iter_(iter) {} + ConstIterator(const ConstIterator& other) : iter_(other.iter_) {} + ConstIterator(const Iterator& other) : iter_(other.iter_) {} + ConstIterator& operator=(const ConstIterator& other) { + iter_ = other.iter_; + return *this; + } + ConstIterator& operator++() { + ++iter_; + return *this; + } + ConstIterator operator++(int) { + 
ConstIterator tmp = *this; + ++(*this); + return tmp; + } + reference operator*() const { return (*iter_)->value; } + pointer get() const { return &**this; } + pointer operator->() const { return &(*iter_)->value; } + friend bool operator==(const ConstIterator& a, const ConstIterator& b) { + return a.iter_ == b.iter_; + } + friend bool operator!=(const ConstIterator& a, const ConstIterator& b) { + return !(a == b); + } + + private: + std::list::const_iterator iter_; + }; + + public: + using size_type = size_t; + using difference_type = ptrdiff_t; + using key_equal = absl::DefaultHashContainerEq; + using value_type = std::pair; + using reference = value_type&; + using const_reference = const value_type&; + using pointer = value_type*; + using const_pointer = const value_type*; + using iterator = Iterator; + using const_iterator = ConstIterator; + + // Creates an empty `NodeMap` associated with the given `Package`. + explicit NodeMap(Package* pkg) + : pkg_(pkg), + id_(pkg->AllocateNodeUserDataId()), + values_(std::make_unique>()) {} + // Creates an empty `NodeMap` which will become associated with a `Package` + // upon first insertion. + NodeMap() + : pkg_(nullptr), + id_(-1), + values_(std::make_unique>()) {} + // Releases all data held by this map and informs the package that this + // map's user-data ID can be reused. + ~NodeMap() { + // Release all the data. + clear(); + } + // Copy constructor. + NodeMap(const NodeMap& other) + : pkg_(other.pkg_), + id_(pkg_ != nullptr ? pkg_->AllocateNodeUserDataId() : -1), + values_(std::make_unique>()) { + for (auto& [k, v] : other) { + this->insert(k, v); + } + } + // Copy assignment. + NodeMap& operator=(const NodeMap& other) { + if (HasPackage()) { + clear(); + } else { + pkg_ = other.pkg_; + id_ = pkg_ != nullptr ? pkg_->AllocateNodeUserDataId() : -1; + } + for (auto& [k, v] : other) { + this->insert(k, v); + } + return *this; + } + // Move constructor. 
+ NodeMap(NodeMap&& other) { + pkg_ = other.pkg_; + id_ = other.id_; + values_ = std::move(other.values_); + other.id_ = -1; + other.pkg_ = nullptr; + } + // Move assignment. + NodeMap& operator=(NodeMap&& other) { + pkg_ = other.pkg_; + id_ = other.id_; + values_ = std::move(other.values_); + other.pkg_ = nullptr; + other.id_ = -1; + return *this; + } + // Range constructor. + template + NodeMap(It first, It last) + : pkg_(nullptr), + id_(-1), + values_(std::make_unique>()) { + for (auto& [k, v] : iter::imap( + [this](auto& kv) { return std::make_pair(kv.first, kv.second); }, + iter::zip(iter::imap([](auto& kv) { return kv.first; }, first), + iter::imap([](auto& kv) { return kv.second; }, last)))) { + this->insert(k, v); + } + } + // Range constructor with explicit package. + template + NodeMap(Package* pkg, It first, It last) + : pkg_(pkg), + id_(pkg->AllocateNodeUserDataId()), + values_(std::make_unique>()) { + for (auto& [k, v] : iter::imap( + [this](auto& kv) { return std::make_pair(kv.first, kv.second); }, + iter::zip(iter::imap([](auto& kv) { return kv.first; }, first), + iter::imap([](auto& kv) { return kv.second; }, last)))) { + this->insert(k, v); + } + } + // Constructs a `NodeMap` from an `absl::flat_hash_map`. + NodeMap(const absl::flat_hash_map& other) + : pkg_(nullptr), + id_(-1), + values_(std::make_unique>()) { + for (auto& [k, v] : other) { + this->insert(k, v); + } + } + // Assigns contents from an `absl::flat_hash_map`. + NodeMap& operator=(const absl::flat_hash_map& other) { + clear(); + for (auto& [k, v] : other) { + this->insert(k, v); + } + } + // Initializer list constructor. + NodeMap(std::initializer_list init) + : pkg_(nullptr), + id_(-1), + values_(std::make_unique>()) { + for (const auto& pair : init) { + insert(pair.first, pair.second); + } + } + + // Returns true if the map has an associated package. + bool HasPackage() const { return pkg_ != nullptr; } + + // Returns true if the map contains no elements. 
+ bool empty() const { + CheckValidId(); + return values_->empty(); + } + // Returns the number of elements in the map. + size_t size() const { + CheckValidId(); + return values_->size(); + } + + // Returns true if the map contains an element with key `n`. + bool contains(Node* n) const { + CheckValidId(n); + if (!HasPackage()) { + return false; + } + return n->HasUserData(id_); + } + // Returns 1 if the map contains an element with key `n`, 0 otherwise. + size_t count(Node* n) const { + CheckValidId(n); + if (!HasPackage()) { + return 0; + } + return n->HasUserData(id_) ? 1 : 0; + } + // Returns a reference to the value mapped to key `n`. If no such element + // exists, this function CHECK-fails. + ValueT& at(Node* n) { + EnsureValidId(n); + CHECK(contains(n)) << "Nothing was ever set for " << n; + return reinterpret_cast(n->GetUserData(id_))->value.second; + } + // Returns a reference to the value mapped to key `n`, inserting a + // default-constructed value if `n` is not already present. + ValueT& operator[](Node* n) { + EnsureValidId(n); + if (contains(n)) { + return at(n); + } + auto holder = std::make_unique(n); + DataHolder* holder_ptr = holder.get(); + n->SetUserData(id_, EraseType(std::move(holder))); + values_->push_front(holder_ptr); + holder_ptr->iter = values_->begin(); + holder_ptr->configured_list = values_.get(); + return holder_ptr->value.second; + } + // Returns a const reference to the value mapped to key `n`. If no such + // element exists, this function CHECK-fails. + const ValueT& operator[](Node* n) const { return at(n); } + // Returns a const reference to the value mapped to key `n`. If no such + // element exists, this function CHECK-fails. + const ValueT& at(Node* n) const { + CheckValidId(n); + CHECK(contains(n)) << "Nothing was ever set for " << n; + return reinterpret_cast(n->GetUserData(id_))->value.second; + } + + // Erases the element with key `n` if it exists. 
+ void erase(Node* n) { + CheckValidId(n); + if (contains(n)) { + std::optional data = n->TakeUserData(id_); + DCHECK(data); + // The DataHolder could remove itself from the list when its destructor + // runs. It seems better to just be explicit + DataHolder* holder = reinterpret_cast(data->get()); + DCHECK(holder->configured_list == values_.get()); + holder->configured_list = nullptr; + values_->erase(holder->iter); + } + } + + // Erases the element pointed to by `it`. Returns an iterator to the + // element following the erased element. + const_iterator erase(const_iterator it) { + CheckValidId(it->first); + auto res = it; + ++res; + erase(it->first); + return res; + } + + // Erases the element pointed to by `it`. Returns an iterator to the + // element following the erased element. + iterator erase(iterator it) { + CheckValidId(it->first); + auto res = it; + ++res; + erase(it->first); + return res; + } + + // Removes all elements from the map. + ABSL_ATTRIBUTE_REINITIALIZES void clear() { + CheckValidId(); + if (pkg_ == nullptr) { + return; + } + for (DataHolder* v : *values_) { + std::optional data = + v->value.first->TakeUserData(id_); + // We can't remove the current iterator position. + reinterpret_cast(data->get())->configured_list = nullptr; + } + values_->clear(); + // Release the id. + pkg_->ReleaseNodeUserDataId(id_); + } + // Swaps the contents of this map with `other`. + void swap(NodeMap& other) { + std::swap(pkg_, other.pkg_); + std::swap(id_, other.id_); + values_.swap(other.values_); + } + + // Returns an iterator to the first element in the map. + iterator begin() { + CheckValidId(); + return Iterator(values_->begin()); + } + // Returns an iterator to the element following the last element in the map. + iterator end() { + CheckValidId(); + return Iterator(values_->end()); + } + // Returns a const iterator to the first element in the map. 
+ const_iterator cbegin() const { + CheckValidId(); + return ConstIterator(values_->cbegin()); + } + // Returns a const iterator to the element following the last element in the + // map. + const_iterator cend() const { + CheckValidId(); + return ConstIterator(values_->cend()); + } + // Returns a const iterator to the first element in the map. + const_iterator begin() const { + CheckValidId(); + return cbegin(); + } + // Returns a const iterator to the element following the last element in the + // map. + const_iterator end() const { + CheckValidId(); + return cend(); + } + + // Finds an element with key `n`. + // Returns an iterator to the element if found, or `end()` otherwise. + iterator find(Node* n) { + CheckValidId(n); + if (!contains(n)) { + return end(); + } + return Iterator(reinterpret_cast(n->GetUserData(id_))->iter); + } + // Finds an element with key `n`. + // Returns a const iterator to the element if found, or `end()` otherwise. + const_iterator find(Node* n) const { + CheckValidId(n); + if (!contains(n)) { + return end(); + } + return ConstIterator( + reinterpret_cast(n->GetUserData(id_))->iter); + } + + // Inserts a key-value pair into the map if the key does not already exist. + // Returns a pair consisting of an iterator to the inserted element (or to + // the element that prevented the insertion) and a bool denoting whether + // the insertion took place. + std::pair insert(Node* n, ValueT value) { + EnsureValidId(n); + if (contains(n)) { + return std::make_pair(find(n), false); + } + auto holder = std::make_unique(n, std::move(value)); + DataHolder* holder_ptr = holder.get(); + n->SetUserData(id_, EraseType(std::move(holder))); + values_->push_front(holder_ptr); + holder_ptr->iter = values_->begin(); + holder_ptr->configured_list = values_.get(); + return std::make_pair(begin(), true); + } + + // Inserts a key-value pair into the map or assigns to the existing value if + // the key already exists. 
+ // Returns a pair consisting of an iterator to the inserted element (or to + // the element that prevented the insertion) and a bool denoting whether + // the insertion took place. + std::pair insert_or_assign(Node* n, ValueT value) { + EnsureValidId(n); + if (contains(n)) { + Iterator f = find(n); + f->second = std::move(value); + return std::make_pair(f, false); + } + auto holder = std::make_unique(n, std::move(value)); + DataHolder* holder_ptr = holder.get(); + n->SetUserData(id_, EraseType(std::move(holder))); + values_->push_front(holder_ptr); + holder_ptr->iter = values_->begin(); + holder_ptr->configured_list = values_.get(); + return std::make_pair(begin(), true); + } + + // Inserts an element constructed in-place if the key does not already exist. + // Note: Unlike `try_emplace`, `emplace` may construct `ValueT` from `args` + // even if insertion does not occur. + // Returns a pair consisting of an iterator to the inserted element (or to + // the element that prevented the insertion) and a bool denoting whether + // the insertion took place. + template + std::pair emplace(Node* n, Args&&... args) { + EnsureValidId(n); + // If key already exists, construct elements but don't insert. + // This is to match std::map::emplace behavior where element construction + // might happen before check for duplication and value is discarded. + auto holder = std::make_unique(n, std::forward(args)...); + if (contains(n)) { + return std::make_pair(find(n), false); + } + DataHolder* holder_ptr = holder.get(); + n->SetUserData(id_, EraseType(std::move(holder))); + values_->push_front(holder_ptr); + holder_ptr->iter = values_->begin(); + holder_ptr->configured_list = values_.get(); + return std::make_pair(begin(), true); + } + + // Inserts an element constructed in-place if the key does not already exist. + // If the key already exists, no element is constructed. 
+ // Returns a pair consisting of an iterator to the inserted element (or to + // the element that prevented the insertion) and a bool denoting whether + // the insertion took place. + template + std::pair try_emplace(Node* n, Args&&... args) { + EnsureValidId(n); + if (contains(n)) { + return std::make_pair(find(n), false); + } + auto holder = std::make_unique(n, std::forward(args)...); + DataHolder* holder_ptr = holder.get(); + n->SetUserData(id_, EraseType(std::move(holder))); + values_->push_front(holder_ptr); + holder_ptr->iter = values_->begin(); + holder_ptr->configured_list = values_.get(); + return std::make_pair(begin(), true); + } + + friend bool operator==(const Iterator& a, const ConstIterator& b) { + return a.iter_ == b.iter_; + } + friend bool operator==(const ConstIterator& a, const Iterator& b) { + return a.iter_ == b.iter_; + } + friend bool operator!=(const Iterator& a, const ConstIterator& b) { + return a.iter_ != b.iter_; + } + friend bool operator!=(const ConstIterator& a, const Iterator& b) { + return a.iter_ != b.iter_; + } + + private: + void CheckValid() const { +#ifdef DEBUG + CHECK(HasPackage()); + CheckValidId(); +#endif + } + void CheckValidId() const { +#ifdef DEBUG + if (pkg_ != nullptr) { + CHECK(pkg_->IsLiveUserDataId(id_)) << id_; + } +#endif + } + + // Check that this map has a valid id and correct package. 
+ void CheckValidId(Node* n) const { +#ifdef DEBUG + CheckValidId(); + if (HasPackage()) { + CHECK_EQ(n->package(), pkg_) + << "Incorrect package for " << n << " got " << n->package()->name() + << " expected " << pkg_->name(); + } +#endif + } + // Force this map to have a user-data id if it doesn't already + void EnsureValidId(Node* n) { + if (!HasPackage()) { + pkg_ = n->package(); + DCHECK(pkg_ != nullptr) + << "Cannot add a node " << n << " without a package."; + id_ = pkg_->AllocateNodeUserDataId(); + } + CheckValidId(n); + } + + Package* pkg_; + int64_t id_; + std::unique_ptr> values_; +}; + +} // namespace xls + +#endif // XLS_IR_NODE_MAP_H_ diff --git a/xls/ir/node_map_test.cc b/xls/ir/node_map_test.cc new file mode 100644 index 0000000000..5a061cc524 --- /dev/null +++ b/xls/ir/node_map_test.cc @@ -0,0 +1,408 @@ +// Copyright 2025 The XLS Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "xls/ir/node_map.h" + +#include +#include +#include + +#include "benchmark/benchmark.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/base/attributes.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "xls/common/status/matchers.h" +#include "xls/common/status/status_macros.h" +#include "xls/ir/benchmark_support.h" +#include "xls/ir/bits.h" +#include "xls/ir/function_builder.h" +#include "xls/ir/ir_matcher.h" +#include "xls/ir/ir_test_base.h" +#include "xls/ir/package.h" + +namespace m = ::xls::op_matchers; + +namespace xls { +namespace { + +using testing::_; +using testing::Eq; +using testing::Pair; +using testing::UnorderedElementsAre; + +template +using TestNodeMap = NodeMap; + +struct EmplaceOnly { + inline static int constructions = 0; + int x; + explicit EmplaceOnly(int i) : x(i) { ++constructions; } + EmplaceOnly(const EmplaceOnly&) = delete; + EmplaceOnly& operator=(const EmplaceOnly&) = delete; + EmplaceOnly(EmplaceOnly&&) = default; + EmplaceOnly& operator=(EmplaceOnly&&) = default; + bool operator==(const EmplaceOnly& other) const { return x == other.x; } +}; + +class NodeMapTest : public IrTestBase {}; + +TEST_F(NodeMapTest, Basic) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + auto a = fb.Param("foo", p->GetBitsType(32)); + auto b = fb.Param("bar", p->GetBitsType(32)); + auto c = fb.Add(a, b); + XLS_ASSERT_OK(fb.Build().status()); + + TestNodeMap map; + // Set. 
+ map[a.node()] = 1; + map[b.node()] = 2; + + EXPECT_THAT(map.at(a.node()), Eq(1)); + EXPECT_THAT(map.at(b.node()), Eq(2)); + EXPECT_TRUE(map.contains(a.node())); + EXPECT_TRUE(map.contains(b.node())); + EXPECT_FALSE(map.contains(c.node())); + EXPECT_EQ(map.count(a.node()), 1); + EXPECT_EQ(map.count(b.node()), 1); + EXPECT_EQ(map.count(c.node()), 0); + EXPECT_THAT(map, UnorderedElementsAre(Pair(m::Param("foo"), Eq(1)), + Pair(m::Param("bar"), Eq(2)))); + + // Update. + map[a.node()] += 5; + + EXPECT_THAT(map.at(a.node()), Eq(6)); + EXPECT_THAT(map.at(b.node()), Eq(2)); + EXPECT_TRUE(map.contains(a.node())); + EXPECT_TRUE(map.contains(b.node())); + EXPECT_FALSE(map.contains(c.node())); + EXPECT_EQ(map.count(a.node()), 1); + EXPECT_EQ(map.count(b.node()), 1); + EXPECT_EQ(map.count(c.node()), 0); + EXPECT_THAT(map, UnorderedElementsAre(Pair(m::Param("foo"), Eq(6)), + Pair(m::Param("bar"), Eq(2)))); + // erase. + map.erase(a.node()); + + EXPECT_THAT(map.at(b.node()), Eq(2)); + EXPECT_TRUE(map.contains(b.node())); + EXPECT_FALSE(map.contains(c.node())); + EXPECT_FALSE(map.contains(a.node())); + EXPECT_EQ(map.count(a.node()), 0); + EXPECT_EQ(map.count(b.node()), 1); + EXPECT_EQ(map.count(c.node()), 0); + EXPECT_THAT(map, UnorderedElementsAre(Pair(m::Param("bar"), Eq(2)))); +} + +TEST_F(NodeMapTest, Find) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + auto a = fb.Param("foo", p->GetBitsType(32)); + auto b = fb.Param("bar", p->GetBitsType(32)); + auto c = fb.Add(a, b); + XLS_ASSERT_OK(fb.Build().status()); + + TestNodeMap map; + // Set. + map[a.node()] = 1; + map[b.node()] = 2; + + EXPECT_THAT(map.find(c.node()), Eq(map.end())); + EXPECT_THAT(map.find(a.node()), + testing::Pointee(Pair(m::Param("foo"), Eq(1)))); + EXPECT_THAT(map.find(b.node()), + testing::Pointee(Pair(m::Param("bar"), Eq(2)))); + // Update with iterator. 
+ map.find(a.node())->second = 33; + EXPECT_THAT(map, UnorderedElementsAre(Pair(m::Param("foo"), Eq(33)), + Pair(m::Param("bar"), Eq(2)))); +} + +TEST_F(NodeMapTest, Copy) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + auto a = fb.Param("foo", p->GetBitsType(32)); + auto b = fb.Param("bar", p->GetBitsType(32)); + fb.Add(a, b); + XLS_ASSERT_OK(fb.Build().status()); + + TestNodeMap map; + { + TestNodeMap map1; + // Set. + map1[a.node()] = 1; + map1[b.node()] = 2; + map = map1; + map[a.node()] = 4; + map1[b.node()] = 6; + EXPECT_THAT(map1, UnorderedElementsAre(Pair(m::Param("foo"), Eq(1)), + Pair(m::Param("bar"), Eq(6)))); + EXPECT_THAT(map, UnorderedElementsAre(Pair(m::Param("foo"), Eq(4)), + Pair(m::Param("bar"), Eq(2)))); + } + EXPECT_THAT(map, UnorderedElementsAre(Pair(m::Param("foo"), Eq(4)), + Pair(m::Param("bar"), Eq(2)))); +} + +TEST_F(NodeMapTest, Move) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + auto a = fb.Param("foo", p->GetBitsType(32)); + auto b = fb.Param("bar", p->GetBitsType(32)); + auto c = fb.Add(a, b); + XLS_ASSERT_OK(fb.Build().status()); + + std::optional> opt_map; + { + TestNodeMap map1; + // Set. + map1[a.node()] = 1; + map1[b.node()] = 2; + opt_map.emplace(std::move(map1)); + EXPECT_FALSE(map1.HasPackage()); + } + + TestNodeMap map(*std::move(opt_map)); + + EXPECT_THAT(map.find(c.node()), Eq(map.end())); + EXPECT_THAT(map.find(a.node()), + testing::Pointee(Pair(m::Param("foo"), Eq(1)))); + EXPECT_THAT(map.find(b.node()), + testing::Pointee(Pair(m::Param("bar"), Eq(2)))); +} + +TEST_F(NodeMapTest, Insert) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + auto a = fb.Param("foo", p->GetBitsType(32)); + auto b = fb.Param("bar", p->GetBitsType(32)); + auto c = fb.Add(a, b); + XLS_ASSERT_OK(fb.Build().status()); + + TestNodeMap map; + // Set. 
+ map[a.node()] = 1; + map[b.node()] = 2; + + EXPECT_THAT(map, UnorderedElementsAre(Pair(m::Param("foo"), Eq(1)), + Pair(m::Param("bar"), Eq(2)))); + EXPECT_THAT(map.insert(c.node(), 3), Pair(_, true)); + EXPECT_THAT(map, UnorderedElementsAre(Pair(m::Param("foo"), Eq(1)), + Pair(m::Param("bar"), Eq(2)), + Pair(m::Add(), Eq(3)))); + + EXPECT_THAT(map.insert(a.node(), 3), Pair(map.find(a.node()), false)); + EXPECT_THAT(map, UnorderedElementsAre(Pair(m::Param("foo"), Eq(1)), + Pair(m::Param("bar"), Eq(2)), + Pair(m::Add(), Eq(3)))); + + EXPECT_THAT(map.insert_or_assign(a.node(), 3), + Pair(map.find(a.node()), false)); + EXPECT_THAT(map, UnorderedElementsAre(Pair(m::Param("foo"), Eq(3)), + Pair(m::Param("bar"), Eq(2)), + Pair(m::Add(), Eq(3)))); + + map.erase(c.node()); + EXPECT_THAT(map.insert_or_assign(c.node(), 7), Pair(_, true)); + EXPECT_THAT(map, UnorderedElementsAre(Pair(m::Param("foo"), Eq(3)), + Pair(m::Param("bar"), Eq(2)), + Pair(m::Add(), Eq(7)))); +} + +TEST_F(NodeMapTest, Emplace) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + auto a = fb.Param("foo", p->GetBitsType(32)); + XLS_ASSERT_OK(fb.Build().status()); + + TestNodeMap map; + EmplaceOnly::constructions = 0; + auto [it, inserted] = map.emplace(a.node(), 42); + EXPECT_TRUE(inserted); + EXPECT_TRUE(map.contains(a.node())); + EXPECT_EQ(map.at(a.node()).x, 42); + EXPECT_EQ(it->second.x, 42); + EXPECT_EQ(EmplaceOnly::constructions, 1); + // Emplace on existing fails but constructs argument. 
+ auto [it2, inserted2] = map.emplace(a.node(), 44); + EXPECT_FALSE(inserted2); + EXPECT_EQ(map.at(a.node()).x, 42); + EXPECT_EQ(it2->second.x, 42); + EXPECT_EQ(it2, it); + EXPECT_EQ(EmplaceOnly::constructions, 2); +} + +TEST_F(NodeMapTest, TryEmplace) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + auto a = fb.Param("foo", p->GetBitsType(32)); + XLS_ASSERT_OK(fb.Build().status()); + + TestNodeMap map; + EmplaceOnly::constructions = 0; + auto [it, inserted] = map.try_emplace(a.node(), 42); + EXPECT_TRUE(inserted); + EXPECT_TRUE(map.contains(a.node())); + EXPECT_EQ(map.at(a.node()).x, 42); + EXPECT_EQ(it->second.x, 42); + EXPECT_EQ(EmplaceOnly::constructions, 1); + // try_emplace on existing fails and does not construct argument. + auto [it2, inserted2] = map.try_emplace(a.node(), 44); + EXPECT_FALSE(inserted2); + EXPECT_EQ(map.at(a.node()).x, 42); + EXPECT_EQ(it2->second.x, 42); + EXPECT_EQ(it2, it); + EXPECT_EQ(EmplaceOnly::constructions, 1); +} + +TEST_F(NodeMapTest, Swap) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + auto a = fb.Param("foo", p->GetBitsType(32)); + auto b = fb.Param("bar", p->GetBitsType(32)); + XLS_ASSERT_OK(fb.Build().status()); + + TestNodeMap map1; + TestNodeMap map2; + map1[a.node()] = 1; + map2[b.node()] = 2; + map1.swap(map2); + EXPECT_THAT(map1, UnorderedElementsAre(Pair(m::Param("bar"), Eq(2)))); + EXPECT_THAT(map2, UnorderedElementsAre(Pair(m::Param("foo"), Eq(1)))); +} + +TEST_F(NodeMapTest, InitializerList) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + auto a = fb.Param("foo", p->GetBitsType(32)); + auto b = fb.Param("bar", p->GetBitsType(32)); + XLS_ASSERT_OK(fb.Build().status()); + + TestNodeMap map{{a.node(), 1}, {b.node(), 2}}; + EXPECT_THAT(map, UnorderedElementsAre(Pair(m::Param("foo"), Eq(1)), + Pair(m::Param("bar"), Eq(2)))); +} + +TEST_F(NodeMapTest, NodeDeletionRemovesMapElement) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), 
p.get()); + auto a = fb.Param("foo", p->GetBitsType(32)); + auto b = fb.Param("bar", p->GetBitsType(32)); + auto c = fb.Param("baz", p->GetBitsType(32)); + // Remove doesn't like getting rid of the last node. + fb.Param("other", p->GetBitsType(32)); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build()); + + TestNodeMap map{{a.node(), 1}, {b.node(), 2}, {c.node(), 3}}; + EXPECT_THAT(map, UnorderedElementsAre(Pair(m::Param("foo"), Eq(1)), + Pair(m::Param("bar"), Eq(2)), + Pair(m::Param("baz"), Eq(3)))); + XLS_ASSERT_OK(f->RemoveNode(b.node())); + EXPECT_THAT(map, UnorderedElementsAre(Pair(m::Param("foo"), Eq(1)), + Pair(m::Param("baz"), Eq(3)))); + XLS_ASSERT_OK(f->RemoveNode(a.node())); + EXPECT_THAT(map, UnorderedElementsAre(Pair(m::Param("baz"), Eq(3)))); + XLS_ASSERT_OK(f->RemoveNode(c.node())); + EXPECT_THAT(map, UnorderedElementsAre()); +} + +absl::Status GenerateFunction(Package* p, benchmark::State& state) { + FunctionBuilder fb("benchmark", p); + XLS_RETURN_IF_ERROR( + benchmark_support::GenerateChain( + fb, state.range(0), 2, benchmark_support::strategy::BinaryAdd(), + benchmark_support::strategy::SharedLiteral(UBits(32, 32))) + .status()); + return fb.Build().status(); +} + +template +void BM_ReadSome(benchmark::State& v, Setup setup) { + Package p("benchmark"); + XLS_ASSERT_OK(GenerateFunction(&p, v)); + Function* f = p.functions()[0].get(); + for (auto s : v) { + Map map; + setup(map); + int64_t i = 0; + for (Node* n : f->nodes()) { + if (i++ % 4 == 0) { + map[n].value = i; + } + } + for (int64_t i = 0; i < v.range(1); ++i) { + for (Node* n : f->nodes()) { + auto v = map.find(n); + if (v != map.end()) { + benchmark::DoNotOptimize(v->second.value); + } + benchmark::DoNotOptimize(v); + } + } + for (Node* n : f->nodes()) { + if (i++ % 3 == 0) { + map.erase(n); + } + if (i++ % 7 == 0) { + map[n].value = i; + } + } + for (int64_t i = 0; i < v.range(1); ++i) { + for (Node* n : f->nodes()) { + auto v = map.find(n); + if (v != map.end()) { + 
benchmark::DoNotOptimize(v->second.value); + } + benchmark::DoNotOptimize(v); + } + } + } +} + +// Simulate a typical xls map value which has real destructors etc. +struct TestValue { + int64_t value; + TestValue() ABSL_ATTRIBUTE_NOINLINE : value(12) {} + explicit TestValue(int64_t v) ABSL_ATTRIBUTE_NOINLINE : value(v) { + benchmark::DoNotOptimize(v); + } + TestValue(const TestValue& v) ABSL_ATTRIBUTE_NOINLINE : value(v.value) { + benchmark::DoNotOptimize(v); + } + ~TestValue() ABSL_ATTRIBUTE_NOINLINE { benchmark::DoNotOptimize(value); } + TestValue(TestValue&& v) ABSL_ATTRIBUTE_NOINLINE : value(v.value) { + benchmark::DoNotOptimize(v); + } +}; +void BM_ReadSomeNodeMap(benchmark::State& v) { + BM_ReadSome>(v, [](auto& a) {}); +} +void BM_ReadSomeFlatMap(benchmark::State& v) { + BM_ReadSome>(v, [](auto& a) {}); +} +void BM_ReadSomeFlatMapReserve(benchmark::State& v) { + BM_ReadSome>( + v, [&v](auto& map) { map.reserve(v.range(0)); }); +} +BENCHMARK(BM_ReadSomeNodeMap)->RangePair(100, 100000, 1, 100); +BENCHMARK(BM_ReadSomeFlatMap)->RangePair(100, 100000, 1, 100); +BENCHMARK(BM_ReadSomeFlatMapReserve)->RangePair(100, 100000, 1, 100); + +} // namespace +} // namespace xls diff --git a/xls/ir/package.cc b/xls/ir/package.cc index 3d665f4ab1..452415c582 100644 --- a/xls/ir/package.cc +++ b/xls/ir/package.cc @@ -29,6 +29,7 @@ #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" @@ -37,6 +38,7 @@ #include "absl/types/span.h" #include "xls/common/status/ret_check.h" #include "xls/common/status/status_macros.h" +#include "xls/data_structures/inline_bitmap.h" #include "xls/ir/block.h" #include "xls/ir/call_graph.h" #include "xls/ir/channel.h" @@ -56,7 +58,7 @@ namespace xls { -Package::Package(std::string_view name) : name_(name) {} +Package::Package(std::string_view name) : name_(name), 
user_data_ids_(64) {}
 
 Package::~Package() = default;
 
@@ -950,6 +952,57 @@ TransformMetricsProto TransformMetrics::ToProto() const {
   return ret;
 }
 
+namespace {
+// True when assertions are compiled in, i.e. NDEBUG is not defined.
+#ifdef NDEBUG
+static constexpr bool kDebugMode = false;
+#else
+static constexpr bool kDebugMode = true;
+#endif
+}  // namespace
+
+void Package::ReleaseNodeUserDataId(int64_t id) {
+  CHECK(user_data_ids_.Get(id)) << "id: " << id;
+  user_data_ids_.Set(id, false);
+  if constexpr (kDebugMode) {
+    // Callers must have removed every node's data for this id before
+    // releasing it; walk all nodes in the package to verify that.
+    for (FunctionBase* fb : GetFunctionBases()) {
+      for (Node* n : fb->nodes()) {
+        CHECK(!n->HasUserData(id))
+            << "id: " << id << " node: " << n->ToString();
+      }
+    }
+  }
+}
+
+int64_t Package::AllocateNodeUserDataId() {
+  if (user_data_ids_.IsAllOnes()) {
+    // Every id is in use; grow the bitmap by 64 bits.
+    int64_t size = user_data_ids_.bit_count();
+    user_data_ids_ = std::move(user_data_ids_).WithSize(size + 64);
+    LOG(WARNING) << "Excessive user data live use: " << (size + 1)
+                 << " live data!";
+  }
+  int64_t off = 0;
+  // Find first word with a false bit.
+  while (!(~user_data_ids_.GetWord(off / InlineBitmap::kWordBits))) {
+    off += InlineBitmap::kWordBits;
+  }
+  // Find the first byte in that word with a false bit. Compare against 0xFF
+  // rather than testing `~byte`: the operand of `~` is promoted to int, so
+  // `~byte` is never zero and the loop would never advance.
+  while (user_data_ids_.GetByte(off / 8) == 0xFF) {
+    off += 8;
+  }
+  // Find the bit.
+  while (user_data_ids_.Get(off)) {
+    off++;
+  }
+  user_data_ids_.Set(off);
+  return off;
+}
+
 // Printers for fuzztest use.
namespace { void WriteParseFunction(const Package& p, std::ostream* os) { diff --git a/xls/ir/package.h b/xls/ir/package.h index ddede3db2d..02a32e519a 100644 --- a/xls/ir/package.h +++ b/xls/ir/package.h @@ -16,11 +16,13 @@ #define XLS_IR_PACKAGE_H_ #include +#include #include #include #include #include #include +#include #include #include "absl/container/flat_hash_map.h" @@ -28,6 +30,7 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xls/data_structures/inline_bitmap.h" #include "xls/ir/channel.h" #include "xls/ir/channel_ops.h" #include "xls/ir/fileno.h" @@ -448,6 +451,21 @@ class Package { } TransformMetrics& transform_metrics() { return transform_metrics_; } + // Allocate a new user data id. This function will not reuse an id until + // ReleaseNodeUserDataId is called on it. + int64_t AllocateNodeUserDataId(); + + // Releases the user data id and allows it to be reused. + // + // NB This must be called once for each value returned by + // AllocateNodeUserDataId. + // + // When this is called all nodes with user data *MUST* have *already* had + // TakeUserData called on them to delete the user data associated with them. + // On DEBUG builds this is CHECKed. + void ReleaseNodeUserDataId(int64_t id); + bool IsLiveUserDataId(int64_t id) { return user_data_ids_.Get(id); } + private: std::vector GetChannelNames() const; @@ -493,6 +511,9 @@ class Package { // Metrics which record the total number of transformations to the package. TransformMetrics transform_metrics_ = {0}; + + // Bitmap containing allocated user data ids. + InlineBitmap user_data_ids_; }; // Printers for fuzztest use.