import torch
import torch_xla.core.xla_model as xm

- from vllm.v1.sample.metadata import SamplingMetadata
+ from vllm.v1.worker.gpu_input_batch import InputBatch
+
+ DEFAULT_SAMPLING_PARAMS = dict(
+     temperature=-1.0,
+     min_p=0.0,
+     # strictly disabled for now
+     # top_k=-1,
+     # top_p=0.0,
+     # frequency_penalties=0.0,
+     # presence_penalties=0.0,
+     # repetition_penalties=0.0,
+ )


@dataclass
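A hedged note on the `-1.0` temperature fill above (my reading, not stated in the diff): with standard temperature scaling the logits are divided by the temperature tensor, so a `0.0` fill in padded slots would produce inf/NaN inside the traced graph even though those rows are later discarded; a nonzero sentinel keeps the math well-defined, and it may also match the sentinel value `InputBatch` stores for greedy requests. A minimal illustration with made-up sizes:

```python
import torch

logits = torch.randn(4, 32)  # padded batch of 4, vocab of 32 (illustrative)
# 2 real requests; the padded tail gets the -1.0 sentinel fill.
temperature = torch.tensor([0.7, 0.2, -1.0, -1.0])

# Scaling stays finite for every row; a 0.0 fill would yield inf here.
scaled = logits / temperature.unsqueeze(1)
```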
@@ -20,14 +31,8 @@ class TPUSupportedSamplingMetadata:
    top_k: torch.Tensor = None
    top_p: torch.Tensor = None

-     # XLA-unfriendly control flow in Sampler
-     all_greedy: bool = False
-     all_random: bool = False
    # Greedy sampling flag for compiling single xla graph.
-     do_argmax: torch.Tensor = None
-
-     # speculation not supported
-     spec_token_ids = None
+     all_greedy: torch.Tensor = None

    # Generator not supported by xla
    generators: dict[int,
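Making `all_greedy` a scalar bool tensor instead of a Python bool removes host-side control flow from the traced sampler, which is what the "single xla graph" comment refers to. A minimal sketch of the pattern (illustrative only, not the PR's sampler code):

```python
import torch

def pick_tokens(logits: torch.Tensor, all_greedy: torch.Tensor,
                sampled: torch.Tensor) -> torch.Tensor:
    # Both branches are computed unconditionally; the scalar flag only
    # selects the result, so XLA traces one graph for both modes instead
    # of recompiling per Python branch.
    greedy = logits.argmax(dim=-1)
    return torch.where(all_greedy, greedy, sampled)

# Example: batch of 4 rows, flag passed as a scalar bool tensor.
logits = torch.randn(4, 32)
sampled = torch.randint(0, 32, (4,))
tokens = pick_tokens(logits, torch.tensor(True), sampled)
```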
@@ -54,106 +59,68 @@ class TPUSupportedSamplingMetadata:
    bad_words_token_ids = None
    indices_do_sample: torch.Tensor = None

-     def __post_init__(self):
-         temp = self.temperature
-         if self.indices_do_sample is None:
-             self.indices_do_sample = torch.zeros(temp.shape[0],
-                                                  device=temp.device,
-                                                  dtype=torch.int32)
-         if self.do_argmax is None:
-             self.do_argmax = torch.tensor(0,
-                                           dtype=torch.bool,
-                                           device=temp.device)
-
    @classmethod
-     def from_sampling_metadata(
-             cls, metadata: SamplingMetadata,
-             padded_do_sample_indices: torch.Tensor, num_do_sample: int,
-             device: torch.device) -> "TPUSupportedSamplingMetadata":
+     def from_input_batch(
+             cls, input_batch: InputBatch,
+             indices_do_sample: torch.Tensor) -> "TPUSupportedSamplingMetadata":
        """
-         Create an XLA-frienly SamplingMetadata structure. Do so by first
-         instantiating an object with fixed-sized tensors and then writing the
-         values in input `metadata`. Do that only for non-None values so that
-         recompilation is not triggered for optional values (None/torch.Tensor).
-
-         In order to handle different sizes for the params that range from 1 up
-         to `max_num_seqs`, pad tensors to the closest pre-compiled shape.
-         Same thing for `padded_do_sample_indices`, which contains the indices
-         to be fed to the Sampler, padded to the closest pre-compiled shape.
-
-         Eg. pad to 4 temperature: [0.7, 0.2]=>[0.7, 0.2, 0.0, 0.0]
-         do_sample_indices: [4, 10]=>padded_do_sample_indices: [4, 10, 0, 0]
+         Copy sampling tensor slices from `input_batch` to on-device tensors.
+
+         `InputBatch._make_sampling_metadata` causes recompilation on XLA as it
+         slices dynamic shapes on device tensors. This implementation moves the
+         dynamic ops to the CPU and produces tensors of a fixed
+         `padded_num_reqs` size. It also reuses the on-device persistent
+         tensors managed in `input_batch` to reduce waste.
+
+         `indices_do_sample` contains the indices to be fed to the Sampler,
+         normally one per request, here padded to the closest pre-compiled
+         shape. We expect the sampling params tensors to be padded to the same
+         fixed shape.
+
+         E.g. 3 requests, tensors padded to 4:
+         temperature: [0.7, 0.2, 0.9] => [0.7, 0.2, 0.9, 0.0]
+         sample indices: [4, 10, 11] => indices_do_sample: [4, 10, 11, 0]
        """
-         metadata = cls._validate_sampling_metadata(metadata)
-         # NOTE we have to initialize default tensor-based params first and
-         # skip None values altogether to produce the same xla graph.
-         num_samples = len(padded_do_sample_indices)
-         do_argmax = torch.tensor(metadata.all_greedy,
-                                  dtype=torch.bool,
-                                  device=device)
-         new_metadata = cls.get_default_sampling_params(
-             num_samples, device,
-             indices_do_sample=padded_do_sample_indices,
-             do_argmax=do_argmax)
-         supported_params = \
-             TPUSupportedSamplingMetadata._get_default_params_values()
-         # Copy input non-None values into `new_metadata` fixed-sized tensors.
-         for p_name in supported_params:
-             old_val = getattr(metadata, p_name)
-             new_val = getattr(new_metadata, p_name)
-             if isinstance(old_val, torch.Tensor):
-                 new_val[:num_do_sample] = old_val
-             setattr(new_metadata, p_name, new_val)
+         num_reqs = input_batch.num_reqs
+         padded_num_reqs = len(indices_do_sample)
+
+         def copy_slice(cpu_tensor: torch.Tensor, tpu_tensor: torch.Tensor,
+                        fill_val) -> torch.Tensor:
+             # Copy a slice from the CPU tensor to the corresponding
+             # pre-allocated TPU tensor; the padded tail gets the default
+             # fill value.
+             cpu_tensor[num_reqs:padded_num_reqs] = fill_val
+             tpu_tensor[:padded_num_reqs] = cpu_tensor[:padded_num_reqs]
+
+         # NOTE NickLucche The sync CPU-TPU graph we produce here must be
+         # consistent. We can't have flags to skip copies or we'll end up
+         # recompiling.
+         copy_slice(input_batch.temperature_cpu_tensor,
+                    input_batch.temperature,
+                    DEFAULT_SAMPLING_PARAMS["temperature"])
+         # TODO Temporarily disabled until sampling options are enabled
+         # copy_slice(input_batch.top_p_cpu_tensor, input_batch.top_p)
+         # copy_slice(input_batch.top_k_cpu_tensor, input_batch.top_k)
+         copy_slice(input_batch.min_p_cpu_tensor, input_batch.min_p,
+                    DEFAULT_SAMPLING_PARAMS["min_p"])
+
+         # copy_slice(input_batch.frequency_penalties_cpu_tensor,
+         #            input_batch.frequency_penalties)
+         # copy_slice(input_batch.presence_penalties_cpu_tensor,
+         #            input_batch.presence_penalties)
+         # copy_slice(input_batch.repetition_penalties_cpu_tensor,
+         #            input_batch.repetition_penalties)

        xm.mark_step()
        xm.wait_device_ops()
-         return new_metadata

-     @classmethod
-     def get_default_sampling_params(
-             cls,
-             num_samples: int,
-             device: torch.device,
-             indices_do_sample=None,
-             do_argmax=None) -> "TPUSupportedSamplingMetadata":
-         # As sampling happens on a single traced graph, options
-         # are "disabled" by having them evaluate to an Identity op.
-         # Note that initialization is dependent on num_samples.
-         sampling_metadata_disable_value = \
-             TPUSupportedSamplingMetadata._get_default_params_values()
-         init_kwargs = dict()
-         for p_name, (default_val,
-                      dtype) in sampling_metadata_disable_value.items():
-             default_tensor = torch.full((num_samples, ),
-                                         default_val,
-                                         dtype=dtype,
-                                         device=device)
-             init_kwargs[p_name] = default_tensor
-
-         return cls(**init_kwargs,
-                    indices_do_sample=indices_do_sample,
-                    do_argmax=do_argmax)
-
-     @staticmethod
-     def _validate_sampling_metadata(
-             sampling_metadata: SamplingMetadata) -> SamplingMetadata:
-         if sampling_metadata.all_greedy:
-             # Set to None since #13587. Make sure default isn't overruled.
-             assert sampling_metadata.temperature is None
-         return sampling_metadata
-
-     @staticmethod
-     def _get_default_params_values():
-         return dict(
-             # Since #13587 greedy sampling requires branching off which leads
-             # to separate graphs. We set temp to noop and handle argmax here.
-             temperature=(1.0, torch.float32),
-             min_p=(0.0, torch.float32),
-             # strictly disabled for now
-             # top_k=(-1, torch.int32),
-             # top_p=(0.0, torch.float32),
-             # frequency_penalties=(0.0, torch.float32),
-             # presence_penalties=(0.0, torch.float32),
-             # repetition_penalties=(0.0, torch.float32),
-         )
+         # Slice persistent device tensors to a fixed pre-compiled padded
+         # shape.
+         return cls(
+             temperature=input_batch.temperature[:padded_num_reqs],
+             # Scalar tensor for xla-friendly tracing.
+             all_greedy=torch.tensor(input_batch.all_greedy,
+                                     dtype=torch.bool,
+                                     device=input_batch.device),
+             # TODO enable more and avoid returning None values
+             top_p=None,  # input_batch.top_p[:padded_num_reqs],
+             top_k=None,  # input_batch.top_k[:padded_num_reqs],
+             min_p=input_batch.min_p[:padded_num_reqs],
+             generators=input_batch.generators,
+             indices_do_sample=indices_do_sample)
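For context, a hedged sketch of how the new constructor might be driven by a caller; `pad_to_compiled`, the `COMPILED_SIZES` tuple, and the concrete numbers are assumptions for illustration, not vLLM's actual runner code:

```python
import torch

COMPILED_SIZES = (8, 16, 32)  # hypothetical pre-compiled batch shapes

def pad_to_compiled(n: int) -> int:
    # Round up to the nearest pre-compiled size so XLA reuses a cached graph.
    return next(s for s in COMPILED_SIZES if s >= n)

# Suppose 3 requests sample at flattened token positions 4, 10 and 11.
num_reqs = 3
padded = pad_to_compiled(num_reqs)  # -> 8
indices = torch.zeros(padded, dtype=torch.int32)
indices[:num_reqs] = torch.tensor([4, 10, 11], dtype=torch.int32)

# metadata = TPUSupportedSamplingMetadata.from_input_batch(input_batch, indices)
```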