@@ -457,7 +457,12 @@ def time_per_output_token_ms(self) -> Optional[float]: # type: ignore[override]
457457 This includes the time to generate the first token and all other tokens.
458458 None if the output_tokens is None or 0.
459459 """
460- if self .output_tokens is None or self .output_tokens == 0 :
460+ if (
461+ self .output_tokens is None
462+ or self .output_tokens == 0
463+ or self .first_token_time is None
464+ or self .last_token_time is None
465+ ):
461466 return None
462467
463468 return super ().time_per_output_token_ms
@@ -614,41 +619,46 @@ def duration(self) -> float:
614619 ),
615620 )
616621
def set_sample_size(self, sample_size: Optional[int]) -> "GenerativeBenchmark":
    """
    Randomly down-sample the stored requests of this benchmark in place.

    Each of requests.successful, requests.errored, and requests.incomplete
    is independently sampled down to ``sample_size`` entries, or kept whole
    if it already holds fewer entries than that. ``request_samples`` is then
    rebuilt from the resulting list lengths.
    If None, no sampling is applied and the current state is kept.

    :param sample_size: The number of requests to sample for each status type,
        or None to leave the benchmark untouched.
    :return: This same benchmark instance (mutated, not copied).
    :raises ValueError: If the sample size is not a non-negative integer.
    """
    if sample_size is None:
        return self

    # Check the type first: comparing a non-numeric value with `< 0` would
    # raise TypeError before the documented ValueError could be produced.
    if not isinstance(sample_size, int) or sample_size < 0:
        raise ValueError(
            f"Sample size must be non-negative integer, given {sample_size}"
        )

    requests = self.requests
    successful = requests.successful
    errored = requests.errored
    incomplete = requests.incomplete

    # Cap every status type against its own population. Rebinding
    # `sample_size` to the successful cap (as was done previously) would also
    # limit errored/incomplete samples to the number of successful requests.
    requests.successful = random.sample(
        successful, min(sample_size, len(successful))
    )
    requests.errored = random.sample(errored, min(sample_size, len(errored)))
    requests.incomplete = random.sample(
        incomplete, min(sample_size, len(incomplete))
    )

    self.request_samples = StatusBreakdown(
        successful=len(requests.successful),
        incomplete=len(requests.incomplete),
        errored=len(requests.errored),
    )

    return self
652662
653663 @staticmethod
654664 def from_stats (
0 commit comments