From ad06f072dda969b04a17535873e89c8eda9a8cd5 Mon Sep 17 00:00:00 2001
From: Laxma Reddy Patlolla
Date: Wed, 11 Jun 2025 14:50:08 -0700
Subject: [PATCH 1/3] fix-sharded-weights-typeerror

---
 keras_hub/src/utils/preset_utils.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/keras_hub/src/utils/preset_utils.py b/keras_hub/src/utils/preset_utils.py
index 918c3e1869..c397c625cb 100644
--- a/keras_hub/src/utils/preset_utils.py
+++ b/keras_hub/src/utils/preset_utils.py
@@ -732,7 +732,13 @@ def _get_sharded_filenames(self, config_path):
         with open(config_path, encoding="utf-8") as config_file:
             config = json.load(config_file)
         weight_map = config["weight_map"]
-        return sorted(set(weight_map.values()))
+        filenames = set()
+        for v in weight_map.values():
+            if isinstance(v, list):
+                filenames.update(v)
+            else:
+                filenames.add(v)
+        return sorted(filenames)
 
     def _load_backbone_weights(self, backbone):
         # Detect if the backbone is sharded or not.

From 51843dcceef150d24912c7d00f9ead7f5352b80a Mon Sep 17 00:00:00 2001
From: Laxma Reddy Patlolla
Date: Thu, 12 Jun 2025 14:40:14 -0700
Subject: [PATCH 2/3] Add unit test case

---
 keras_hub/src/utils/preset_utils_test.py | 40 ++++++++++++++++++++++++
 1 file changed, 40 insertions(+)

diff --git a/keras_hub/src/utils/preset_utils_test.py b/keras_hub/src/utils/preset_utils_test.py
index 5e6888c965..9a038abc2f 100644
--- a/keras_hub/src/utils/preset_utils_test.py
+++ b/keras_hub/src/utils/preset_utils_test.py
@@ -59,6 +59,46 @@ def test_sharded_weights(self):
         ):
             self.assertAllClose(v1, v2)
 
+    @pytest.mark.large
+    def test_sharded_weights_config(self):
+        if not sharded_weights_available():
+            self.skipTest("Sharded weights are not available.")
+
+        init_kwargs = {
+            "vocabulary_size": 1024,
+            "num_layers": 12,
+            "num_query_heads": 8,
+            "num_key_value_heads": 4,
+            "hidden_dim": 32,
+            "intermediate_dim": 64,
+            "head_dim": 4,
+            "sliding_window_size": 5,
+            "attention_logit_soft_cap": 50,
+            "final_logit_soft_cap": 30,
+            "layer_norm_epsilon": 1e-6,
+            "query_head_dim_normalize": False,
+            "use_post_ffw_norm": True,
+            "use_post_attention_norm": True,
+            "use_sliding_window_attention": True,
+        }
+        backbone = GemmaBackbone(**init_kwargs)  # ~422KB
+
+        # Save the sharded weights.
+        preset_dir = self.get_temp_dir()
+        backbone.save_to_preset(preset_dir, max_shard_size=0.0002)
+        config_file = os.path.join(preset_dir, "model.weights.json")
+        self.assertTrue(os.path.exists(config_file))
+        config = json.load(open(config_file, "r"))
+        weight_map = config["weight_map"]
+        # check the format of `weight_map`
+        self.assertIsInstance(weight_map, dict)
+        for key, value in weight_map.items():
+            self.assertIsInstance(key, str)
+            if isinstance(value, list):
+                self.assertTrue(all(isinstance(item, str) for item in value))
+            else:
+                self.assertIsInstance(value, str)
+
     @pytest.mark.large
     def test_disabled_sharding(self):
         init_kwargs = {

From 7a74ee1b32226e0e423ea840cbe9489e4ffb0fe4 Mon Sep 17 00:00:00 2001
From: Laxma Reddy Patlolla
Date: Fri, 13 Jun 2025 14:51:55 -0700
Subject: [PATCH 3/3] Moved common initialization code to setUp for cleaner
 tests and updated all three relevant test cases to use the shared setup.

---
 keras_hub/src/utils/preset_utils_test.py | 57 ++++++------------------
 1 file changed, 14 insertions(+), 43 deletions(-)

diff --git a/keras_hub/src/utils/preset_utils_test.py b/keras_hub/src/utils/preset_utils_test.py
index 9a038abc2f..ed55e63885 100644
--- a/keras_hub/src/utils/preset_utils_test.py
+++ b/keras_hub/src/utils/preset_utils_test.py
@@ -18,12 +18,8 @@ class PresetUtilsTest(TestCase):
-    @pytest.mark.large
-    def test_sharded_weights(self):
-        if not sharded_weights_available():
-            self.skipTest("Sharded weights are not available.")
-
-        init_kwargs = {
+    def setUp(self):
+        self.init_kwargs = {
             "vocabulary_size": 1024,
             "num_layers": 12,
             "num_query_heads": 8,
@@ -40,7 +36,14 @@
             "use_post_attention_norm": True,
             "use_sliding_window_attention": True,
         }
-        backbone = GemmaBackbone(**init_kwargs)  # ~422KB
+
+    @pytest.mark.large
+    def test_sharded_weights(self):
+        if not sharded_weights_available():
+            self.skipTest("Sharded weights are not available.")
+
+        # Create a GemmaBackbone instance with the specified init kwargs
+        backbone = GemmaBackbone(**self.init_kwargs)  # ~422KB
 
         # Save the sharded weights.
         preset_dir = self.get_temp_dir()
@@ -64,24 +67,8 @@ def test_sharded_weights_config(self):
         if not sharded_weights_available():
             self.skipTest("Sharded weights are not available.")
 
-        init_kwargs = {
-            "vocabulary_size": 1024,
-            "num_layers": 12,
-            "num_query_heads": 8,
-            "num_key_value_heads": 4,
-            "hidden_dim": 32,
-            "intermediate_dim": 64,
-            "head_dim": 4,
-            "sliding_window_size": 5,
-            "attention_logit_soft_cap": 50,
-            "final_logit_soft_cap": 30,
-            "layer_norm_epsilon": 1e-6,
-            "query_head_dim_normalize": False,
-            "use_post_ffw_norm": True,
-            "use_post_attention_norm": True,
-            "use_sliding_window_attention": True,
-        }
-        backbone = GemmaBackbone(**init_kwargs)  # ~422KB
+        # Create a GemmaBackbone instance with the specified init kwargs
+        backbone = GemmaBackbone(**self.init_kwargs)  # ~422KB
 
         # Save the sharded weights.
         preset_dir = self.get_temp_dir()
@@ -101,24 +88,8 @@ def test_disabled_sharding(self):
-        init_kwargs = {
-            "vocabulary_size": 1024,
-            "num_layers": 12,
-            "num_query_heads": 8,
-            "num_key_value_heads": 4,
-            "hidden_dim": 32,
-            "intermediate_dim": 64,
-            "head_dim": 4,
-            "sliding_window_size": 5,
-            "attention_logit_soft_cap": 50,
-            "final_logit_soft_cap": 30,
-            "layer_norm_epsilon": 1e-6,
-            "query_head_dim_normalize": False,
-            "use_post_ffw_norm": True,
-            "use_post_attention_norm": True,
-            "use_sliding_window_attention": True,
-        }
-        backbone = GemmaBackbone(**init_kwargs)
+        # Create a GemmaBackbone instance with the specified init kwargs
+        backbone = GemmaBackbone(**self.init_kwargs)
 
         # Save the weights with `max_shard_size=None`
         preset_dir = self.get_temp_dir()
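
For reference, a minimal standalone sketch of the behavior PATCH 1/3 fixes: in a sharded `model.weights.json`, `weight_map` can map a weight path either to a single shard filename (a str) or to a list of shard filenames, and the old `sorted(set(weight_map.values()))` raised a TypeError on the list case because lists are unhashable. The helper below mirrors the patched `_get_sharded_filenames` logic without the `self`/file-reading parts; the sample `weight_map` keys and filenames are illustrative, not taken from a real preset.

# Python sketch mirroring the flattening logic added in PATCH 1/3.
# The sample weight_map below is hypothetical, for illustration only.
def get_sharded_filenames(weight_map):
    filenames = set()
    for v in weight_map.values():
        if isinstance(v, list):
            # A single weight can be split across several shard files.
            filenames.update(v)
        else:
            # Common case: the weight lives in exactly one shard file.
            filenames.add(v)
    return sorted(filenames)

weight_map = {
    "token_embedding/embeddings": "model_00000.weights.h5",
    "decoder_block_0/attention/query/kernel": [
        "model_00000.weights.h5",
        "model_00001.weights.h5",
    ],
}
print(get_sharded_filenames(weight_map))
# -> ['model_00000.weights.h5', 'model_00001.weights.h5']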