1
+ #!/usr/bin/env python3
2
+ """
3
+ Quick Eagle test script for faster verification during development.
4
+
5
+ This script runs a minimal test to verify Eagle models are working without
6
+ full model initialization overhead.
7
+ """
8
+
9
+ def test_config_loading ():
10
+ """Test that Eagle configs can be loaded properly."""
11
+ print ("Testing Eagle config loading..." )
12
+
13
+ try :
14
+ from vllm .transformers_utils .configs .speculators_eagle import (
15
+ SpeculatorsEagleConfig ,
16
+ is_speculators_eagle_config
17
+ )
18
+
19
+ # Test speculators detection
20
+ is_speculators = is_speculators_eagle_config ("nm-testing/eagle3-llama3.1-8b-instruct-speculators" )
21
+ print (f"✓ Eagle-3 speculators detection: { is_speculators } " )
22
+
23
+ # Test regular Eagle detection
24
+ is_regular = is_speculators_eagle_config ("nm-testing/eagle-llama3.1-8b-instruct" )
25
+ print (f"✓ Regular Eagle detection (should be False): { is_regular } " )
26
+
27
+ # Try loading a speculators config
28
+ if is_speculators :
29
+ config = SpeculatorsEagleConfig .from_pretrained (
30
+ "nm-testing/eagle3-llama3.1-8b-instruct-speculators"
31
+ )
32
+ print (f"✓ Config loaded successfully" )
33
+ print (f" - Method: { getattr (config , 'method' , 'N/A' )} " )
34
+ print (f" - Num lookahead tokens: { getattr (config , 'num_lookahead_tokens' , 'N/A' )} " )
35
+ print (f" - Model type: { getattr (config , 'model_type' , 'N/A' )} " )
36
+
37
+ return True
38
+
39
+ except Exception as e :
40
+ print (f"✗ Config test failed: { str (e )} " )
41
+ return False
42
+
43
+ def test_model_imports ():
44
+ """Test that Eagle model classes can be imported."""
45
+ print ("\n Testing Eagle model imports..." )
46
+
47
+ try :
48
+ # Test V1 Eagle model import
49
+ from vllm .model_executor .models .llama_eagle import EagleLlamaForCausalLM
50
+ print ("✓ V1 Eagle model imported successfully" )
51
+
52
+ # Test V0 Eagle model import
53
+ from vllm .model_executor .models .eagle import EAGLEModel
54
+ print ("✓ V0 Eagle model imported successfully" )
55
+
56
+ # Test detection utilities
57
+ from vllm .engine .arg_utils import EngineArgs
58
+ print ("✓ Engine args imported successfully" )
59
+
60
+ return True
61
+
62
+ except Exception as e :
63
+ print (f"✗ Import test failed: { str (e )} " )
64
+ return False
65
+
66
+ def test_engine_args ():
67
+ """Test that speculative config can be created."""
68
+ print ("\n Testing engine argument handling..." )
69
+
70
+ try :
71
+ from vllm .engine .arg_utils import EngineArgs
72
+
73
+ # Test creating engine args with Eagle-3 speculative config
74
+ args = EngineArgs (
75
+ model = "meta-llama/Meta-Llama-3.1-8B-Instruct" ,
76
+ speculative_config = {
77
+ "method" : "eagle" ,
78
+ "model" : "nm-testing/eagle3-llama3.1-8b-instruct-speculators" ,
79
+ "num_spec_tokens" : 5
80
+ }
81
+ )
82
+
83
+ print ("✓ EngineArgs created successfully" )
84
+
85
+ # Test speculative config creation
86
+ spec_config = args .create_speculative_config (
87
+ args .speculative_config ,
88
+ model_config = None # We're just testing creation
89
+ )
90
+
91
+ if spec_config :
92
+ print ("✓ Speculative config created successfully" )
93
+ print (f" - Method: { spec_config .method } " )
94
+ print (f" - Draft model: { spec_config .model } " )
95
+ print (f" - Spec tokens: { spec_config .num_spec_tokens } " )
96
+
97
+ return True
98
+
99
+ except Exception as e :
100
+ print (f"✗ Engine args test failed: { str (e )} " )
101
+ return False
102
+
103
+ def main ():
104
+ """Run quick tests."""
105
+ print ("Running Quick Eagle Verification Tests" )
106
+ print ("=" * 50 )
107
+
108
+ tests = [
109
+ test_config_loading ,
110
+ test_model_imports ,
111
+ test_engine_args ,
112
+ ]
113
+
114
+ passed = 0
115
+ for test in tests :
116
+ if test ():
117
+ passed += 1
118
+
119
+ print (f"\n { '=' * 50 } " )
120
+ print (f"Quick Tests Summary: { passed } /{ len (tests )} passed" )
121
+
122
+ if passed == len (tests ):
123
+ print ("🎉 All quick tests passed! Eagle support is working." )
124
+ else :
125
+ print ("⚠️ Some tests failed. Check the output above." )
126
+
127
+ return 0 if passed == len (tests ) else 1
128
+
129
+ if __name__ == "__main__" :
130
+ exit (main ())
0 commit comments