@@ -86,6 +86,126 @@ def randn_tensor(
     return latents


+def rand_tensor(
+    shape: Union[Tuple, List],
+    generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
+    device: Optional[Union[str, "torch.device"]] = None,
+    dtype: Optional["torch.dtype"] = None,
+    layout: Optional["torch.layout"] = None,
+):
+    """A helper function to create random tensors on the desired `device` with the desired `dtype`. When
+    passing a list of generators, you can seed each batch size individually. If CPU generators are passed, the
+    tensor is always created on the CPU. This is analogous to `randn_tensor`, except it creates random tensors
+    from the uniform distribution over [0, 1) using `torch.rand`.
+    """
+    # device on which tensor is created defaults to device
+    if isinstance(device, str):
+        device = torch.device(device)
+    rand_device = device
+    batch_size = shape[0]
+
+    layout = layout or torch.strided
+    device = device or torch.device("cpu")
+
+    if generator is not None:
+        gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type
+        if gen_device_type != device.type and gen_device_type == "cpu":
+            rand_device = "cpu"
+            if device != "mps":
+                logger.info(
+                    f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
+                    f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
+                    f" slightly speed up this function by passing a generator that was created on the {device} device."
+                )
+        elif gen_device_type != device.type and gen_device_type == "cuda":
+            raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")
+
+    # make sure generator list of length 1 is treated like a non-list
+    if isinstance(generator, list) and len(generator) == 1:
+        generator = generator[0]
+
+    if isinstance(generator, list):
+        # draw each batch element with its own generator, then concatenate
+        shape = (1,) + shape[1:]
+        latents = [
+            torch.rand(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout)
+            for i in range(batch_size)
+        ]
+        latents = torch.cat(latents, dim=0).to(device)
+    else:
+        latents = torch.rand(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)
+
+    return latents
+
+
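+# Illustrative usage (comments only, not exercised by this diff): a list of
+# generators seeds each batch element independently, mirroring `randn_tensor`.
+# Assumes `torch` is imported at the top of this module.
+#
+#     generators = [torch.Generator("cpu").manual_seed(i) for i in range(4)]
+#     noise = rand_tensor((4, 3, 8, 8), generator=generators, device="cpu")
+#     noise.shape  # torch.Size([4, 3, 8, 8]); sample i depends only on seed i
+
+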
+def multinomial_tensor(
+    logits: torch.Tensor,
+    num_samples: int,
+    replacement: bool = False,
+    generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
+    device: Optional[Union[str, "torch.device"]] = None,
+    squeeze_trailing_dim: bool = True,
+):
+    """
+    Creates a tensor drawn from the multinomial distribution specified by the (possibly unnormalized) probabilities
+    given by `logits`. This is analogous to `randn_tensor`, wrapping `torch.multinomial` rather than `torch.randn`.
+
+    In general, if `logits` has shape [..., num_categories], where the ... represents leading batch dimensions, the
+    output will have shape [..., num_samples]. `logits` is assumed to have at least one leading batch dimension.
+    """
+    batch_size = logits.shape[0]
+    num_cats = logits.shape[-1]
+
+    # mirror `rand_tensor`: normalize `device` so the `.type` comparisons below work
+    if isinstance(device, str):
+        device = torch.device(device)
+    device = device or torch.device("cpu")
+
+    if generator is not None:
+        gen_device = generator.device if not isinstance(generator, list) else generator[0].device
+        gen_device_type = gen_device.type
+        if gen_device_type != device.type and gen_device_type == "cpu":
+            if device != "mps":
+                logger.info(
+                    f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
+                    f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
+                    f" slightly speed up this function by passing a generator that was created on the {device} device."
+                )
+        elif gen_device_type != device.type and gen_device_type == "cuda":
+            raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")
+
+    # make sure generator list of length 1 is treated like a non-list
+    if isinstance(generator, list) and len(generator) == 1:
+        generator = generator[0]
+
+    # Sample on the generator's device; the result is moved to `device` below
+    logits_ = logits.to(gen_device) if generator is not None else logits
+
+    # Multinomial is not implemented for half precision on CPU
+    if logits_.device.type == "cpu" and logits_.dtype != torch.float32:
+        logits_ = logits_.float()
+    if isinstance(generator, list):
+        sample = []
+        original_shape = logits.shape[1:-1]
+        # `torch.multinomial` only accepts 1D or 2D inputs, so flatten any extra
+        # leading dimensions of each batch element and restore them after sampling
+        needs_reshape = logits.ndim > 3
+        for i in range(batch_size):
+            logits_instance = logits_[i]
+            if needs_reshape:
+                logits_instance = logits_instance.reshape(-1, num_cats)
+            sample_instance = torch.multinomial(logits_instance, num_samples, replacement, generator=generator[i])
+            if needs_reshape:
+                sample_instance = sample_instance.view(*original_shape, num_samples)
+            sample.append(sample_instance)
+        sample = torch.stack(sample, dim=0).to(device)
+    else:
+        if logits.ndim > 2:
+            original_shape = logits.shape[:-1]
+            logits_ = logits_.reshape(-1, logits.size(-1))
+        sample = torch.multinomial(logits_, num_samples, replacement, generator=generator).to(device)
+        if logits.ndim > 2:
+            sample = sample.view(*original_shape, num_samples)
+
+    if squeeze_trailing_dim:
+        sample = sample.squeeze(-1)
+
+    return sample
+
+
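+# Illustrative usage (comments only, not exercised by this diff): with 2D
+# weights the trailing sample dim is squeezed; extra leading dims are flattened
+# and restored internally.
+#
+#     gen = torch.Generator("cpu").manual_seed(0)
+#     logits = torch.ones(2, 5)                # 2 batch rows, 5 categories
+#     ids = multinomial_tensor(logits, num_samples=1, generator=gen)
+#     ids.shape  # torch.Size([2])
+#
+#     spatial = torch.rand(2, 4, 4, 16)        # [batch, H, W, num_categories]
+#     tok = multinomial_tensor(spatial, num_samples=1, generator=gen)
+#     tok.shape  # torch.Size([2, 4, 4])
+
+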
 def is_compiled_module(module) -> bool:
     """Check whether the module was compiled with torch.compile()"""
     if is_torch_version("<", "2.0.0") or not hasattr(torch, "_dynamo"):