From 1fa186fe1e813e604eff1fa23377a21f07cee802 Mon Sep 17 00:00:00 2001 From: Samuel Monson Date: Fri, 18 Jul 2025 13:17:05 -0400 Subject: [PATCH 01/20] Add turns support to synthetic dataset Signed-off-by: Samuel Monson --- src/guidellm/dataset/synthetic.py | 103 ++++++++++++++++++++---------- 1 file changed, 71 insertions(+), 32 deletions(-) diff --git a/src/guidellm/dataset/synthetic.py b/src/guidellm/dataset/synthetic.py index 8c30f0f76..06972643b 100644 --- a/src/guidellm/dataset/synthetic.py +++ b/src/guidellm/dataset/synthetic.py @@ -3,7 +3,7 @@ from collections.abc import Iterable, Iterator from itertools import cycle from pathlib import Path -from typing import Any, Literal, Optional, Union +from typing import Any, Optional, TypedDict, Union import yaml from datasets import ( @@ -69,6 +69,26 @@ class SyntheticDatasetConfig(BaseModel): gt=0, default=None, ) + turns: int = Field( + description="The number of turns in the conversation.", + gt=0, + default=1, + ) + turns_stdev: Optional[int] = Field( + description="The standard deviation of the number of turns.", + gt=0, + default=None, + ) + turns_min: Optional[int] = Field( + description="The minimum number of turns in the conversation.", + gt=0, + default=None, + ) + turns_max: Optional[int] = Field( + description="The maximum number of turns in the conversation.", + gt=0, + default=None, + ) samples: int = Field( description="The number of samples to generate for the dataset.", gt=0, @@ -124,14 +144,13 @@ def parse_config_file(data: Union[str, Path]) -> "SyntheticDatasetConfig": return SyntheticDatasetConfig(**config_dict) -class SyntheticTextItemsGenerator( - Iterable[ - dict[ - Literal["prompt", "prompt_tokens_count", "output_tokens_count"], - Union[str, int], - ] - ] -): +class SyntheticDatasetRow(TypedDict): + prompt: list[str] + prompt_tokens_count: list[int] + output_tokens_count: list[int] + + +class SyntheticTextItemsGenerator(Iterable[SyntheticDatasetRow]): def __init__( self, config: SyntheticDatasetConfig, @@ -147,12 +166,7 @@ def __init__( def __iter__( self, - ) -> Iterator[ - dict[ - Literal["prompt", "prompt_tokens_count", "output_tokens_count"], - Union[str, int], - ] - ]: + ) -> Iterator[SyntheticDatasetRow]: prompt_tokens_sampler = IntegerRangeSampler( average=self.config.prompt_tokens, variance=self.config.prompt_tokens_stdev, @@ -167,6 +181,13 @@ def __iter__( max_value=self.config.output_tokens_max, random_seed=self.random_seed + 1, # ensure diff dist from prompts ) + turns_sampler = IntegerRangeSampler( + average=self.config.turns, + variance=self.config.turns_stdev, + min_value=self.config.turns_min, + max_value=self.config.turns_max, + random_seed=self.random_seed + 7, # ensure diff dist + ) # ensure diff distribution from output tokens rand = random.Random(self.random_seed + 2) # noqa: S311 unique_prefix_iter = cycle(self.processor.get_vocab().values()) @@ -174,24 +195,42 @@ def __iter__( prefix_index = rand.randint(0, len(self.text_creator.words)) prefix_tokens = self._create_prompt(self.config.prefix_tokens, prefix_index) - for _, prompt_tokens, output_tokens in zip( - range(self.config.samples), - prompt_tokens_sampler, - output_tokens_sampler, - ): - start_index = rand.randint(0, len(self.text_creator.words)) - prompt_text = self.processor.decode( - prefix_tokens - + self._create_prompt( - prompt_tokens, start_index, next(unique_prefix_iter) - ), - skip_special_tokens=True, - ) - yield { - "prompt": prompt_text, - "prompt_tokens_count": self.config.prefix_tokens + prompt_tokens, - "output_tokens_count": 
output_tokens, + for _, turns in zip(range(self.config.samples), turns_sampler): + row: SyntheticDatasetRow = { + "prompt": [], + "prompt_tokens_count": [], + "output_tokens_count": [], } + for i, prompt_tokens, output_tokens in zip( + range(turns), + prompt_tokens_sampler, + output_tokens_sampler, + ): + start_index = rand.randint(0, len(self.text_creator.words)) + # Append the prefix tokens only for the first turn + if i == 0: + prompt_text = self.processor.decode( + prefix_tokens + + self._create_prompt( + prompt_tokens, start_index, next(unique_prefix_iter) + ), + skip_special_tokens=True, + ) + row["prompt"].append(prompt_text) + row["prompt_tokens_count"].append(self.config.prefix_tokens + prompt_tokens) + row["output_tokens_count"].append(output_tokens) + else: + prompt_text = self.processor.decode( + self._create_prompt( + prompt_tokens, start_index, next(unique_prefix_iter) + ), + skip_special_tokens=True, + ) + row["prompt"].append(prompt_text) + row["prompt_tokens_count"].append(prompt_tokens) + row["output_tokens_count"].append(output_tokens) + + yield row def _create_prompt( self, prompt_tokens: int, start_index: int, unique_prefix: Optional[int] = None From 7efb7b174336881be06a298d9c214eec650c4d1a Mon Sep 17 00:00:00 2001 From: Samuel Monson Date: Tue, 23 Sep 2025 15:59:05 -0400 Subject: [PATCH 02/20] Add basic multiturn loader support Signed-off-by: Samuel Monson --- src/guidellm/request/loader.py | 47 +++++++++++++++++++++------------- 1 file changed, 29 insertions(+), 18 deletions(-) diff --git a/src/guidellm/request/loader.py b/src/guidellm/request/loader.py index 607a74554..e23e31117 100644 --- a/src/guidellm/request/loader.py +++ b/src/guidellm/request/loader.py @@ -105,14 +105,14 @@ def __init__( self.preserve_iter_state = iter_type == "infinite" # ensure no caching requests self._preserved_iter = None - def __iter__(self) -> Iterator[GenerationRequest]: + def __iter__(self) -> Iterator[list[GenerationRequest]]: scope_create_count = 0 while (dataset_iter := self._get_dataset_iter(scope_create_count)) is not None: scope_create_count += 1 for item in dataset_iter: - yield self._create_request(item) + yield self._create_requests(item) self._preserved_iter = None @@ -260,25 +260,36 @@ def _get_dataset_iter( return dataset_iter - def _create_request(self, item: dict[str, Any]) -> GenerationRequest: - prompt_tokens = ( - item[self.column_mappings["prompt_tokens_count_column"]] + def _create_requests(self, item: dict[str, Any]) -> list[GenerationRequest]: + prompts = list(item[self.column_mappings["prompt_column"]]) + prompts_tokens: list[Optional[int]] = ( + list(item[self.column_mappings["prompt_tokens_count_column"]]) if "prompt_tokens_count_column" in self.column_mappings - else None + else [None] * len(prompts) ) - output_tokens = ( - item[self.column_mappings["output_tokens_count_column"]] + outputs_tokens: list[Optional[int]] = ( + list(item[self.column_mappings["output_tokens_count_column"]]) if "output_tokens_count_column" in self.column_mappings - else None + else [None] * len(prompts) ) - return GenerationRequest( - request_type=settings.preferred_route, - content=item[self.column_mappings["prompt_column"]], - stats=( + if len(prompts) != len(prompts_tokens) != len(outputs_tokens): + raise ValueError( + "Mismatched lengths between prompts and token counts. 
" + f"Prompts: {len(prompts)}, Prompt Tokens: {len(prompts_tokens)}, " + f"Output Tokens: {len(outputs_tokens)}" + ) + + return [ + GenerationRequest( + request_type=settings.preferred_route, + content=prompt, + stats=( {"prompt_tokens": prompt_tokens} if prompt_tokens is not None else {} - ), - constraints=( - {"output_tokens": output_tokens} if output_tokens is not None else {} - ), - ) + ), + constraints=( + {"output_tokens": output_tokens} if output_tokens is not None else {} + ), + ) + for prompt, prompt_tokens, output_tokens in zip(prompts, prompts_tokens, outputs_tokens) + ] From 3f0cdbc1f5594d74d91bf0443cf652fd851f1182 Mon Sep 17 00:00:00 2001 From: Samuel Monson Date: Thu, 25 Sep 2025 15:42:05 -0400 Subject: [PATCH 03/20] Make dict encoding recursive Signed-off-by: Samuel Monson --- src/guidellm/utils/encoding.py | 37 ++++++++-------------------------- 1 file changed, 8 insertions(+), 29 deletions(-) diff --git a/src/guidellm/utils/encoding.py b/src/guidellm/utils/encoding.py index ccd26982b..d4fa007b2 100644 --- a/src/guidellm/utils/encoding.py +++ b/src/guidellm/utils/encoding.py @@ -390,23 +390,11 @@ def to_dict(self, obj: Any) -> Any: if isinstance(obj, BaseModel): return self.to_dict_pydantic(obj) - if isinstance(obj, (list, tuple)) and any( - isinstance(item, BaseModel) for item in obj - ): - return [ - self.to_dict_pydantic(item) if isinstance(item, BaseModel) else item - for item in obj - ] + if isinstance(obj, (list, tuple)): + return [self.to_dict(item) for item in obj] - if isinstance(obj, dict) and any( - isinstance(value, BaseModel) for value in obj.values() - ): - return { - key: self.to_dict_pydantic(value) - if isinstance(value, BaseModel) - else value - for key, value in obj.items() - } + if isinstance(obj, dict): + return {key: self.to_dict(value) for key, value in obj.items()} return obj @@ -418,22 +406,13 @@ def from_dict(self, data: Any) -> Any: :return: Reconstructed object with proper types restored """ if isinstance(data, (list, tuple)): - return [ - self.from_dict_pydantic(item) - if isinstance(item, dict) and "*PYD*" in item - else item - for item in data - ] - elif isinstance(data, dict) and data: + return [self.from_dict(item) for item in data] + + if isinstance(data, dict) and data: if "*PYD*" in data: return self.from_dict_pydantic(data) - return { - key: self.from_dict_pydantic(value) - if isinstance(value, dict) and "*PYD*" in value - else value - for key, value in data.items() - } + return {key: self.from_dict(value) for key, value in data.items()} return data From 220377e31b3c9f9b14a33e991e6ebfc45fb88c08 Mon Sep 17 00:00:00 2001 From: Samuel Monson Date: Thu, 25 Sep 2025 15:54:13 -0400 Subject: [PATCH 04/20] Use details for next request in chain Signed-off-by: Samuel Monson --- src/guidellm/scheduler/worker_group.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/guidellm/scheduler/worker_group.py b/src/guidellm/scheduler/worker_group.py index c1d516f19..355ca86b5 100644 --- a/src/guidellm/scheduler/worker_group.py +++ b/src/guidellm/scheduler/worker_group.py @@ -496,7 +496,11 @@ def _iter(): count = 0 request_info: ScheduledRequestInfo = None - for request in _iter(): + for request_chain in _iter(): + if isinstance(request_chain, (list, tuple)): + request = request_chain[0] + else: + request = request_chain count += 1 if hasattr(request, "request_id"): From 3ac4df61a555d3b195f6bcf93c15ffe0a8f3d17d Mon Sep 17 00:00:00 2001 From: Samuel Monson Date: Fri, 26 Sep 2025 12:10:21 -0400 Subject: [PATCH 05/20] Implement worker 
support for multiturn Signed-off-by: Samuel Monson --- src/guidellm/request/loader.py | 33 +++++++++----- src/guidellm/scheduler/__init__.py | 8 ++++ src/guidellm/scheduler/objects.py | 44 +++++++++++++++--- src/guidellm/scheduler/worker.py | 25 ++++++++-- src/guidellm/scheduler/worker_group.py | 63 +++++++++++++++----------- 5 files changed, 124 insertions(+), 49 deletions(-) diff --git a/src/guidellm/request/loader.py b/src/guidellm/request/loader.py index e23e31117..81aae8fb9 100644 --- a/src/guidellm/request/loader.py +++ b/src/guidellm/request/loader.py @@ -105,7 +105,7 @@ def __init__( self.preserve_iter_state = iter_type == "infinite" # ensure no caching requests self._preserved_iter = None - def __iter__(self) -> Iterator[list[GenerationRequest]]: + def __iter__(self) -> Iterator[list[tuple[GenerationRequest, float]]]: scope_create_count = 0 while (dataset_iter := self._get_dataset_iter(scope_create_count)) is not None: @@ -260,7 +260,9 @@ def _get_dataset_iter( return dataset_iter - def _create_requests(self, item: dict[str, Any]) -> list[GenerationRequest]: + def _create_requests( + self, item: dict[str, Any] + ) -> list[tuple[GenerationRequest, float]]: prompts = list(item[self.column_mappings["prompt_column"]]) prompts_tokens: list[Optional[int]] = ( list(item[self.column_mappings["prompt_tokens_count_column"]]) @@ -281,15 +283,24 @@ def _create_requests(self, item: dict[str, Any]) -> list[GenerationRequest]: ) return [ - GenerationRequest( - request_type=settings.preferred_route, - content=prompt, - stats=( - {"prompt_tokens": prompt_tokens} if prompt_tokens is not None else {} - ), - constraints=( - {"output_tokens": output_tokens} if output_tokens is not None else {} + ( + GenerationRequest( + request_type=settings.preferred_route, + content=prompt, + stats=( + {"prompt_tokens": prompt_tokens} + if prompt_tokens is not None + else {} + ), + constraints=( + {"output_tokens": output_tokens} + if output_tokens is not None + else {} + ), ), + 0.0, # TODO: delay + ) + for prompt, prompt_tokens, output_tokens in zip( + prompts, prompts_tokens, outputs_tokens ) - for prompt, prompt_tokens, output_tokens in zip(prompts, prompts_tokens, outputs_tokens) ] diff --git a/src/guidellm/scheduler/__init__.py b/src/guidellm/scheduler/__init__.py index 646474241..cb225460a 100644 --- a/src/guidellm/scheduler/__init__.py +++ b/src/guidellm/scheduler/__init__.py @@ -15,16 +15,20 @@ from .objects import ( BackendInterface, BackendT, + HistoryT, MeasuredRequestTimings, MultiTurnRequestT, + MultiTurnT, RequestSchedulerTimings, RequestT, ResponseT, + ScheduledRequestAugmentation, ScheduledRequestInfo, SchedulerMessagingPydanticRegistry, SchedulerState, SchedulerUpdateAction, SchedulerUpdateActionProgress, + TurnT, ) from .scheduler import Scheduler from .strategies import ( @@ -56,6 +60,7 @@ "ConstraintInitializer", "ConstraintsInitializerFactory", "Environment", + "HistoryT", "LastCompletionRequestTimings", "MaxDurationConstraint", "MaxErrorRateConstraint", @@ -64,6 +69,7 @@ "MaxNumberConstraint", "MeasuredRequestTimings", "MultiTurnRequestT", + "MultiTurnT", "NoDelayRequestTimings", "NonDistributedEnvironment", "PoissonRateRequestTimings", @@ -71,6 +77,7 @@ "RequestSchedulerTimings", "RequestT", "ResponseT", + "ScheduledRequestAugmentation", "ScheduledRequestInfo", "ScheduledRequestTimings", "Scheduler", @@ -84,6 +91,7 @@ "StrategyType", "SynchronousStrategy", "ThroughputStrategy", + "TurnT", "UnserializableConstraintInitializer", "WorkerProcess", "WorkerProcessGroup", diff --git 
a/src/guidellm/scheduler/objects.py b/src/guidellm/scheduler/objects.py index b7f2efc34..a58d9225e 100644 --- a/src/guidellm/scheduler/objects.py +++ b/src/guidellm/scheduler/objects.py @@ -19,7 +19,6 @@ Literal, Protocol, TypeVar, - Union, ) from pydantic import Field, computed_field @@ -35,34 +34,50 @@ __all__ = [ "BackendInterface", "BackendT", + "HistoryT", "MeasuredRequestTimings", "MultiTurnRequestT", + "MultiTurnT", "RequestSchedulerTimings", "RequestT", "ResponseT", + "ScheduledRequestAugmentation", "ScheduledRequestInfo", "SchedulerMessagingPydanticRegistry", "SchedulerState", "SchedulerUpdateAction", "SchedulerUpdateActionProgress", + "TurnT", ] RequestT = TypeVar("RequestT") """Generic request object type for scheduler processing.""" +# TODO: Remove +MultiTurnRequestT = RequestT + ResponseT = TypeVar("ResponseT") """Generic response object type returned by backend processing.""" -MultiTurnRequestT = TypeAliasType( - "MultiTurnRequestT", - Union[ - list[Union[RequestT, tuple[RequestT, float]]], - tuple[Union[RequestT, tuple[RequestT, float]]], - ], +TurnT = TypeAliasType( + "TurnT", + tuple[RequestT, "ScheduledRequestAugmentation", "ScheduledRequestInfo"], + type_params=(RequestT,), +) + +MultiTurnT = TypeAliasType( + "MultiTurnT", + list[TurnT[RequestT]], type_params=(RequestT,), ) """Multi-turn request structure supporting conversation history with optional delays.""" +HistoryT = TypeAliasType( + "HistoryT", + list[tuple[RequestT, ResponseT]], + type_params=(RequestT, ResponseT), +) + class SchedulerMessagingPydanticRegistry(RegistryMixin[RegistryObjT]): """ @@ -71,6 +86,21 @@ class SchedulerMessagingPydanticRegistry(RegistryMixin[RegistryObjT]): """ +@SchedulerMessagingPydanticRegistry.register() +class ScheduledRequestAugmentation(StandardBaseModel): + """ + Adjustments to scheduler logic for a paired request. + """ + + post_requeue_delay: float = Field( + description=( + "Delay in seconds to wait after a request to " + "queue the next request in the conversation." 
+ ), + default=0.0, + ) + + @SchedulerMessagingPydanticRegistry.register() class RequestSchedulerTimings(StandardBaseModel): """ diff --git a/src/guidellm/scheduler/worker.py b/src/guidellm/scheduler/worker.py index 5f2fb74bb..4513fe3a5 100644 --- a/src/guidellm/scheduler/worker.py +++ b/src/guidellm/scheduler/worker.py @@ -31,9 +31,12 @@ from guidellm.scheduler.objects import ( BackendInterface, + HistoryT, MultiTurnRequestT, + MultiTurnT, RequestT, ResponseT, + ScheduledRequestAugmentation, ScheduledRequestInfo, SchedulerMessagingPydanticRegistry, ) @@ -118,6 +121,9 @@ def __init__( self.startup_completed = False self.backend_started = False self.messaging_started = False + self.turns_queue: list[ + tuple[HistoryT[RequestT, ResponseT], MultiTurnT[RequestT]] + ] = [] def run(self): """ @@ -302,16 +308,19 @@ async def _cancel_requests_loop(self): self._send_update("cancelled", None, request, request_info) async def _process_next_request(self): - request: RequestT | MultiTurnRequestT[RequestT] | None = None + request: RequestT | None = None request_info: ScheduledRequestInfo | None = None response: ResponseT | None = None + aug: ScheduledRequestAugmentation | None = None try: # Pull request from the queue - request, request_info = await self.messaging.get() - - if isinstance(request, (list, tuple)): - raise NotImplementedError("Multi-turn requests are not yet supported") + history, conversation = ( + self.turns_queue.pop(0) + if self.turns_queue + else ([], await self.messaging.get()) + ) + request, aug, request_info = conversation.pop(0) # Calculate targeted start and set pending state for request request_info.scheduler_node_id = self.messaging.worker_index @@ -341,6 +350,12 @@ async def _process_next_request(self): request_info.scheduler_timings.resolve_end = time.time() self._send_update("completed", response, request, request_info) + # If multi-turn, queue up next turn(s) + # TODO: Move to callback and support delay + if conversation: # more turns to process + history.append((request, response)) + self.turns_queue.append((history, conversation)) + response = request = request_info = None except asyncio.CancelledError: # Handle cancellation diff --git a/src/guidellm/scheduler/worker_group.py b/src/guidellm/scheduler/worker_group.py index 355ca86b5..221e95e1e 100644 --- a/src/guidellm/scheduler/worker_group.py +++ b/src/guidellm/scheduler/worker_group.py @@ -26,8 +26,10 @@ from guidellm.scheduler.objects import ( BackendInterface, MultiTurnRequestT, + MultiTurnT, RequestT, ResponseT, + ScheduledRequestAugmentation, ScheduledRequestInfo, SchedulerMessagingPydanticRegistry, SchedulerState, @@ -471,9 +473,9 @@ def __init__( def requests_generator( self, - requests: Iterable[RequestT | MultiTurnRequestT[RequestT]] | None, - cycle_requests: Iterable[RequestT | MultiTurnRequestT[RequestT]] | None, - ) -> Generator[tuple[RequestT | MultiTurnRequestT[RequestT],], None, None]: + requests: Iterable[Iterable[tuple[RequestT, float]]] | None, + cycle_requests: Iterable[Iterable[tuple[RequestT, float]]] | None, + ) -> Generator[MultiTurnT[RequestT], None, None]: """ Generate request-info pairs for worker processing with constraint evaluation. 
@@ -494,31 +496,40 @@ def _iter(): while True: yield from cycle_requests - count = 0 - request_info: ScheduledRequestInfo = None + count: int = 0 + stop_queueing: bool = False + + def _turn_iter(requests_chain: Iterable[tuple[RequestT, float]]): + nonlocal count, stop_queueing + for request, delay in requests_chain: + count += 1 + + if hasattr(request, "request_id"): + request_id = request.request_id + elif hasattr(request, "id"): + request_id = request.id + else: + request_id = str(uuid.uuid4()) + request_augmentation = ScheduledRequestAugmentation( + post_requeue_delay=delay + ) + request_info: ScheduledRequestInfo = ScheduledRequestInfo( + request_id=request_id, + status="queued", + scheduler_process_id=0, + scheduler_start_time=self.start_time, + ) + state_update = self._locked_update(request_info) + yield (request, request_augmentation, request_info) + + if state_update.stop_queueing: + stop_queueing = True + return + for request_chain in _iter(): - if isinstance(request_chain, (list, tuple)): - request = request_chain[0] - else: - request = request_chain - count += 1 - - if hasattr(request, "request_id"): - request_id = request.request_id - elif hasattr(request, "id"): - request_id = request.id - else: - request_id = str(uuid.uuid4()) - request_info: ScheduledRequestInfo = ScheduledRequestInfo( - request_id=request_id, - status="queued", - scheduler_process_id=0, - scheduler_start_time=self.start_time, - ) - state_update = self._locked_update(request_info) - yield (request, request_info) + yield list(_turn_iter(request_chain)) - if state_update.stop_queueing: + if stop_queueing: self.stop_send_requests_event.set() return From a7bf6900fc77125bc3d887701c8a4e89856b89c2 Mon Sep 17 00:00:00 2001 From: Samuel Monson Date: Fri, 26 Sep 2025 12:55:42 -0400 Subject: [PATCH 06/20] Cancel requests in conversation Signed-off-by: Samuel Monson --- src/guidellm/scheduler/worker.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/guidellm/scheduler/worker.py b/src/guidellm/scheduler/worker.py index 4513fe3a5..155552a8d 100644 --- a/src/guidellm/scheduler/worker.py +++ b/src/guidellm/scheduler/worker.py @@ -296,16 +296,22 @@ async def _cancel_requests_loop(self): try: request: RequestT request_info: ScheduledRequestInfo - request, request_info = await self.messaging.get( - timeout=self.messaging.poll_interval + _, conversation = ( + self.turns_queue.pop(0) + if self.turns_queue + else ( + None, + await self.messaging.get(timeout=self.messaging.poll_interval), + ) ) except asyncio.TimeoutError: continue - request_info.scheduler_node_id = self.messaging.worker_index - request_info.error = "Request was cancelled" - request_info.scheduler_timings.resolve_end = time.time() - self._send_update("cancelled", None, request, request_info) + for request, _, request_info in conversation: + request_info.scheduler_node_id = self.messaging.worker_index + request_info.error = "Request was cancelled" + request_info.scheduler_timings.resolve_end = time.time() + self._send_update("cancelled", None, request, request_info) async def _process_next_request(self): request: RequestT | None = None From e276f6c091a9c995d85644dfbbb32ec24a344d33 Mon Sep 17 00:00:00 2001 From: Samuel Monson Date: Fri, 26 Sep 2025 14:38:11 -0400 Subject: [PATCH 07/20] Cancel whole conversation Signed-off-by: Samuel Monson --- src/guidellm/scheduler/worker.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/guidellm/scheduler/worker.py b/src/guidellm/scheduler/worker.py index 
155552a8d..33be659fa 100644 --- a/src/guidellm/scheduler/worker.py +++ b/src/guidellm/scheduler/worker.py @@ -314,6 +314,7 @@ async def _cancel_requests_loop(self): self._send_update("cancelled", None, request, request_info) async def _process_next_request(self): + conversation: MultiTurnT[RequestT] | None = None request: RequestT | None = None request_info: ScheduledRequestInfo | None = None response: ResponseT | None = None @@ -362,7 +363,7 @@ async def _process_next_request(self): history.append((request, response)) self.turns_queue.append((history, conversation)) - response = request = request_info = None + response = request = request_info = conversation = None except asyncio.CancelledError: # Handle cancellation if request is not None and request_info is not None: @@ -375,6 +376,12 @@ async def _process_next_request(self): request_info.error = str(exc) request_info.scheduler_timings.resolve_end = time.time() self._send_update("errored", response, request, request_info) + finally: + if conversation is not None: + for request, _, request_info in conversation: + request_info.error = "Request was cancelled" + request_info.scheduler_timings.resolve_end = time.time() + self._send_update("cancelled", response, request, request_info) def _send_update( self, From 1de1c64a072b00512c97b364d62474d5ff436226 Mon Sep 17 00:00:00 2001 From: Samuel Monson Date: Fri, 26 Sep 2025 14:57:52 -0400 Subject: [PATCH 08/20] Implement multiturn history in openai backend Signed-off-by: Samuel Monson --- src/guidellm/backends/openai.py | 27 +++++++++++++++++++++------ src/guidellm/scheduler/worker.py | 4 +++- 2 files changed, 24 insertions(+), 7 deletions(-) diff --git a/src/guidellm/backends/openai.py b/src/guidellm/backends/openai.py index ce83076fc..acce5f88d 100644 --- a/src/guidellm/backends/openai.py +++ b/src/guidellm/backends/openai.py @@ -16,6 +16,7 @@ import json import time from collections.abc import AsyncIterator +from itertools import chain from pathlib import Path from typing import Any, ClassVar, Optional, Union @@ -29,7 +30,7 @@ GenerationRequestTimings, GenerationResponse, ) -from guidellm.scheduler import ScheduledRequestInfo +from guidellm.scheduler import HistoryT, ScheduledRequestInfo __all__ = ["OpenAIHTTPBackend", "UsageStats"] @@ -280,7 +281,7 @@ async def resolve( self, request: GenerationRequest, request_info: ScheduledRequestInfo, - history: Optional[list[tuple[GenerationRequest, GenerationResponse]]] = None, + history: Optional[HistoryT[GenerationRequest, GenerationResponse]] = None, ) -> AsyncIterator[tuple[GenerationResponse, ScheduledRequestInfo]]: """ Process a generation request and yield progressive responses. @@ -295,10 +296,8 @@ async def resolve( :yields: Tuples of (response, updated_request_info) as generation progresses. """ self._check_in_process() - if history is not None: - raise NotImplementedError( - "Multi-turn requests with conversation history are not yet supported" - ) + if history: + request = self._apply_history(request, history) response = GenerationResponse( request_id=request.request_id, @@ -500,6 +499,22 @@ async def chat_completions( self._get_completions_usage_stats(data), ) + def _apply_history( + self, + request: GenerationRequest, + history: HistoryT[GenerationRequest, GenerationResponse], + ) -> GenerationRequest: + """ + Apply conversation history to the current request. 
+ """ + + def turn_to_text(turn: tuple[GenerationRequest, GenerationResponse]) -> str: + req, res = turn + return f"{req.content}{res.value}" + + request.content = "".join(chain(map(turn_to_text, history), (request.content,))) + return request + def _build_headers( self, api_key: Optional[str], diff --git a/src/guidellm/scheduler/worker.py b/src/guidellm/scheduler/worker.py index 33be659fa..3c980e60c 100644 --- a/src/guidellm/scheduler/worker.py +++ b/src/guidellm/scheduler/worker.py @@ -349,7 +349,9 @@ async def _process_next_request(self): # Process the request with the backend request_info.scheduler_timings.resolve_start = time.time() self._send_update("in_progress", response, request, request_info) - async for resp, info in self.backend.resolve(request, request_info, None): + async for resp, info in self.backend.resolve( + request, request_info, history + ): response = resp request_info = info From 0e8713c776958f230171414fbe271e52ef63198f Mon Sep 17 00:00:00 2001 From: Samuel Monson Date: Fri, 26 Sep 2025 15:49:53 -0400 Subject: [PATCH 09/20] Add wait_then_requeue behavior Signed-off-by: Samuel Monson --- src/guidellm/scheduler/worker.py | 62 +++++++++++++++++++++++++------- 1 file changed, 49 insertions(+), 13 deletions(-) diff --git a/src/guidellm/scheduler/worker.py b/src/guidellm/scheduler/worker.py index 3c980e60c..81b9e2b1d 100644 --- a/src/guidellm/scheduler/worker.py +++ b/src/guidellm/scheduler/worker.py @@ -13,7 +13,7 @@ import time from multiprocessing.synchronize import Barrier as ProcessingBarrier from multiprocessing.synchronize import Event as ProcessingEvent -from typing import Annotated, Generic, Literal +from typing import Annotated, Generic, Literal, TypeAliasType try: import uvloop @@ -50,6 +50,16 @@ __all__ = ["WorkerProcess"] +ProcessRequestT = TypeAliasType( + "ProcessRequestT", + tuple[ + HistoryT[RequestT, ResponseT], + MultiTurnT[RequestT], + ScheduledRequestAugmentation, + ], + type_params=(RequestT, ResponseT), +) + class WorkerProcess(Generic[RequestT, ResponseT]): """ @@ -271,12 +281,20 @@ async def _process_requests_loop(self): async_semaphore = asyncio.Semaphore(self.async_limit) pending_tasks: set[asyncio.Task] = set() - def _task_done(task): + def _task_done(task: asyncio.Task[ProcessRequestT[RequestT, ResponseT]]): pending_tasks.discard(task) async_semaphore.release() - if not task.cancelled() and (exception := task.exception()): - raise exception + if not task.cancelled(): + if exception := task.exception(): + raise exception + + history, conversation, aug = task.result() + if conversation: + requeue_task = asyncio.create_task( + self._wait_then_requeue(history, conversation, aug) + ) + pending_tasks.add(requeue_task) # Main loop; loop until canceled while True: @@ -313,12 +331,14 @@ async def _cancel_requests_loop(self): request_info.scheduler_timings.resolve_end = time.time() self._send_update("cancelled", None, request, request_info) - async def _process_next_request(self): - conversation: MultiTurnT[RequestT] | None = None + async def _process_next_request(self) -> ProcessRequestT[RequestT, ResponseT]: + conversation: MultiTurnT[RequestT] = [] + history: HistoryT[RequestT, ResponseT] = [] request: RequestT | None = None request_info: ScheduledRequestInfo | None = None response: ResponseT | None = None aug: ScheduledRequestAugmentation | None = None + premature_exit: bool = False try: # Pull request from the queue @@ -359,14 +379,12 @@ async def _process_next_request(self): request_info.scheduler_timings.resolve_end = time.time() 
self._send_update("completed", response, request, request_info) - # If multi-turn, queue up next turn(s) - # TODO: Move to callback and support delay - if conversation: # more turns to process - history.append((request, response)) - self.turns_queue.append((history, conversation)) + # Record Turn + history.append((request, response)) - response = request = request_info = conversation = None + response = request = request_info = None except asyncio.CancelledError: + premature_exit = True # Handle cancellation if request is not None and request_info is not None: request_info.error = "Request was cancelled" @@ -374,17 +392,35 @@ async def _process_next_request(self): self._send_update("cancelled", response, request, request_info) raise except Exception as exc: # noqa: BLE001 + premature_exit = True if request is not None and request_info is not None: request_info.error = str(exc) request_info.scheduler_timings.resolve_end = time.time() self._send_update("errored", response, request, request_info) finally: - if conversation is not None: + if premature_exit and conversation: for request, _, request_info in conversation: request_info.error = "Request was cancelled" request_info.scheduler_timings.resolve_end = time.time() self._send_update("cancelled", response, request, request_info) + return history, conversation, aug + + async def _wait_then_requeue( + self, + history: HistoryT[RequestT, ResponseT], + conversation: MultiTurnT[RequestT], + aug: ScheduledRequestAugmentation, + ): + try: + if aug.post_requeue_delay > 0: + await asyncio.sleep(aug.post_requeue_delay) + except asyncio.CancelledError: + # If we are cancelled, dump straight to queue + raise + finally: + self.turns_queue.append((history, conversation)) + def _send_update( self, new_status: Literal[ From cd43b2cf768d21201ec50e17c8025de54e78474b Mon Sep 17 00:00:00 2001 From: Samuel Monson Date: Fri, 26 Sep 2025 16:55:18 -0400 Subject: [PATCH 10/20] Type cleanup Signed-off-by: Samuel Monson --- src/guidellm/scheduler/__init__.py | 10 ++++----- src/guidellm/scheduler/environments.py | 16 +++++++------- src/guidellm/scheduler/objects.py | 28 +++++++++++------------- src/guidellm/scheduler/scheduler.py | 6 +++--- src/guidellm/scheduler/worker.py | 20 +++++++---------- src/guidellm/scheduler/worker_group.py | 30 ++++++++++---------------- tests/unit/scheduler/test_objects.py | 16 -------------- 7 files changed, 46 insertions(+), 80 deletions(-) diff --git a/src/guidellm/scheduler/__init__.py b/src/guidellm/scheduler/__init__.py index cb225460a..4eff5c125 100644 --- a/src/guidellm/scheduler/__init__.py +++ b/src/guidellm/scheduler/__init__.py @@ -15,10 +15,10 @@ from .objects import ( BackendInterface, BackendT, + DatasetIterT, HistoryT, MeasuredRequestTimings, - MultiTurnRequestT, - MultiTurnT, + RequestDataT, RequestSchedulerTimings, RequestT, ResponseT, @@ -28,7 +28,6 @@ SchedulerState, SchedulerUpdateAction, SchedulerUpdateActionProgress, - TurnT, ) from .scheduler import Scheduler from .strategies import ( @@ -59,6 +58,7 @@ "Constraint", "ConstraintInitializer", "ConstraintsInitializerFactory", + "DatasetIterT", "Environment", "HistoryT", "LastCompletionRequestTimings", @@ -68,12 +68,11 @@ "MaxGlobalErrorRateConstraint", "MaxNumberConstraint", "MeasuredRequestTimings", - "MultiTurnRequestT", - "MultiTurnT", "NoDelayRequestTimings", "NonDistributedEnvironment", "PoissonRateRequestTimings", "PydanticConstraintInitializer", + "RequestDataT", "RequestSchedulerTimings", "RequestT", "ResponseT", @@ -91,7 +90,6 @@ "StrategyType", 
"SynchronousStrategy", "ThroughputStrategy", - "TurnT", "UnserializableConstraintInitializer", "WorkerProcess", "WorkerProcessGroup", diff --git a/src/guidellm/scheduler/environments.py b/src/guidellm/scheduler/environments.py index 6234f8f67..a98535445 100644 --- a/src/guidellm/scheduler/environments.py +++ b/src/guidellm/scheduler/environments.py @@ -19,14 +19,14 @@ import time from abc import ABC, abstractmethod -from collections.abc import AsyncIterator, Iterable +from collections.abc import AsyncIterator from typing import ( Generic, ) from guidellm.scheduler.constraints import Constraint from guidellm.scheduler.objects import ( - MultiTurnRequestT, + DatasetIterT, RequestT, ResponseT, ScheduledRequestInfo, @@ -52,11 +52,11 @@ class Environment(ABC, Generic[RequestT, ResponseT], InfoMixin): @abstractmethod async def sync_run_params( self, - requests: Iterable[RequestT | MultiTurnRequestT[RequestT]], + requests: DatasetIterT[RequestT], strategy: SchedulingStrategy, constraints: dict[str, Constraint], ) -> tuple[ - Iterable[RequestT | MultiTurnRequestT[RequestT]], + DatasetIterT[RequestT], SchedulingStrategy, dict[str, Constraint], ]: @@ -130,7 +130,7 @@ async def sync_run_end( ) -> AsyncIterator[ tuple[ ResponseT, - RequestT | MultiTurnRequestT[RequestT], + RequestT, ScheduledRequestInfo, SchedulerState, ] @@ -194,11 +194,11 @@ def __init__(self): async def sync_run_params( self, - requests: Iterable[RequestT | MultiTurnRequestT[RequestT]], + requests: DatasetIterT[RequestT], strategy: SchedulingStrategy, constraints: dict[str, Constraint], ) -> tuple[ - Iterable[RequestT | MultiTurnRequestT[RequestT]], + DatasetIterT[RequestT], SchedulingStrategy, dict[str, Constraint], ]: @@ -250,7 +250,7 @@ async def sync_run_end( ) -> AsyncIterator[ tuple[ ResponseT, - RequestT | MultiTurnRequestT[RequestT], + RequestT, ScheduledRequestInfo, SchedulerState, ] diff --git a/src/guidellm/scheduler/objects.py b/src/guidellm/scheduler/objects.py index a58d9225e..e7d4c6c7a 100644 --- a/src/guidellm/scheduler/objects.py +++ b/src/guidellm/scheduler/objects.py @@ -11,7 +11,7 @@ import time import uuid -from collections.abc import AsyncIterator +from collections.abc import AsyncIterator, Iterable from typing import ( Any, ClassVar, @@ -34,10 +34,10 @@ __all__ = [ "BackendInterface", "BackendT", + "DatasetIterT", "HistoryT", "MeasuredRequestTimings", - "MultiTurnRequestT", - "MultiTurnT", + "RequestDataT", "RequestSchedulerTimings", "RequestT", "ResponseT", @@ -47,36 +47,32 @@ "SchedulerState", "SchedulerUpdateAction", "SchedulerUpdateActionProgress", - "TurnT", ] RequestT = TypeVar("RequestT") """Generic request object type for scheduler processing.""" -# TODO: Remove -MultiTurnRequestT = RequestT - ResponseT = TypeVar("ResponseT") """Generic response object type returned by backend processing.""" -TurnT = TypeAliasType( - "TurnT", +RequestDataT = TypeAliasType( + "RequestDataT", tuple[RequestT, "ScheduledRequestAugmentation", "ScheduledRequestInfo"], type_params=(RequestT,), ) - -MultiTurnT = TypeAliasType( - "MultiTurnT", - list[TurnT[RequestT]], - type_params=(RequestT,), -) -"""Multi-turn request structure supporting conversation history with optional delays.""" +"""Request including external metadata and scheduling config.""" HistoryT = TypeAliasType( "HistoryT", list[tuple[RequestT, ResponseT]], type_params=(RequestT, ResponseT), ) +"""Record of requests + responses in conversation.""" + + +DatasetIterT = TypeAliasType( + "DatasetIterT", Iterable[Iterable[tuple[RequestT, float]]], 
type_params=(RequestT,) +) class SchedulerMessagingPydanticRegistry(RegistryMixin[RegistryObjT]): diff --git a/src/guidellm/scheduler/scheduler.py b/src/guidellm/scheduler/scheduler.py index e7d8b2c68..43948d18a 100644 --- a/src/guidellm/scheduler/scheduler.py +++ b/src/guidellm/scheduler/scheduler.py @@ -10,7 +10,7 @@ from __future__ import annotations -from collections.abc import AsyncIterator, Iterable +from collections.abc import AsyncIterator from typing import Any, Generic from guidellm.scheduler.constraints import ( @@ -20,7 +20,7 @@ from guidellm.scheduler.environments import Environment, NonDistributedEnvironment from guidellm.scheduler.objects import ( BackendInterface, - MultiTurnRequestT, + DatasetIterT, RequestT, ResponseT, ScheduledRequestInfo, @@ -66,7 +66,7 @@ class Scheduler( async def run( self, - requests: Iterable[RequestT | MultiTurnRequestT[RequestT]], + requests: DatasetIterT[RequestT], backend: BackendInterface[RequestT, ResponseT], strategy: SchedulingStrategy, env: Environment | None, diff --git a/src/guidellm/scheduler/worker.py b/src/guidellm/scheduler/worker.py index 81b9e2b1d..4c5903fb2 100644 --- a/src/guidellm/scheduler/worker.py +++ b/src/guidellm/scheduler/worker.py @@ -32,8 +32,7 @@ from guidellm.scheduler.objects import ( BackendInterface, HistoryT, - MultiTurnRequestT, - MultiTurnT, + RequestDataT, RequestT, ResponseT, ScheduledRequestAugmentation, @@ -54,7 +53,7 @@ "ProcessRequestT", tuple[ HistoryT[RequestT, ResponseT], - MultiTurnT[RequestT], + list[RequestDataT[RequestT]], ScheduledRequestAugmentation, ], type_params=(RequestT, ResponseT), @@ -87,11 +86,8 @@ class WorkerProcess(Generic[RequestT, ResponseT]): def __init__( self, messaging: InterProcessMessaging[ - tuple[ - ResponseT | None, - RequestT | MultiTurnRequestT[RequestT], - ScheduledRequestInfo, - ], + tuple[ResponseT | None, RequestT, ScheduledRequestInfo], + list[RequestDataT[RequestT]], ], backend: BackendInterface[RequestT, ResponseT], request_timings: ScheduledRequestTimings, @@ -132,7 +128,7 @@ def __init__( self.backend_started = False self.messaging_started = False self.turns_queue: list[ - tuple[HistoryT[RequestT, ResponseT], MultiTurnT[RequestT]] + tuple[HistoryT[RequestT, ResponseT], list[RequestDataT[RequestT]]] ] = [] def run(self): @@ -332,7 +328,7 @@ async def _cancel_requests_loop(self): self._send_update("cancelled", None, request, request_info) async def _process_next_request(self) -> ProcessRequestT[RequestT, ResponseT]: - conversation: MultiTurnT[RequestT] = [] + conversation: list[RequestDataT[RequestT]] = [] history: HistoryT[RequestT, ResponseT] = [] request: RequestT | None = None request_info: ScheduledRequestInfo | None = None @@ -409,7 +405,7 @@ async def _process_next_request(self) -> ProcessRequestT[RequestT, ResponseT]: async def _wait_then_requeue( self, history: HistoryT[RequestT, ResponseT], - conversation: MultiTurnT[RequestT], + conversation: list[RequestDataT[RequestT]], aug: ScheduledRequestAugmentation, ): try: @@ -427,7 +423,7 @@ def _send_update( "pending", "in_progress", "completed", "errored", "cancelled" ], response: ResponseT | None, - request: RequestT | MultiTurnRequestT[RequestT], + request: RequestT, request_info: ScheduledRequestInfo, ): prev_status = request_info.status diff --git a/src/guidellm/scheduler/worker_group.py b/src/guidellm/scheduler/worker_group.py index 221e95e1e..296152a8c 100644 --- a/src/guidellm/scheduler/worker_group.py +++ b/src/guidellm/scheduler/worker_group.py @@ -25,8 +25,8 @@ from guidellm.scheduler.constraints 
import Constraint, RequestsExhaustedConstraint from guidellm.scheduler.objects import ( BackendInterface, - MultiTurnRequestT, - MultiTurnT, + DatasetIterT, + RequestDataT, RequestT, ResponseT, ScheduledRequestAugmentation, @@ -83,8 +83,8 @@ class WorkerProcessGroup(Generic[RequestT, ResponseT]): def __init__( self, - requests: Iterable[RequestT | MultiTurnRequestT[RequestT]] | None, - cycle_requests: Iterable[RequestT | MultiTurnRequestT[RequestT]] | None, + requests: DatasetIterT[RequestT] | None, + cycle_requests: DatasetIterT[RequestT] | None, backend: BackendInterface[RequestT, ResponseT], strategy: SchedulingStrategy, constraints: dict[str, Constraint], @@ -131,16 +131,8 @@ def __init__( # Scheduler and messaging state, created in start self.state: WorkerGroupState[ResponseT, RequestT] = None self.messaging: InterProcessMessaging[ - tuple[ - RequestT | MultiTurnRequestT[RequestT], - ScheduledRequestInfo, - ], - tuple[ - ResponseT | None, - RequestT | MultiTurnRequestT[RequestT], - ScheduledRequestInfo, - SchedulerState, - ], + list[RequestDataT[RequestT]], + tuple[ResponseT | None, RequestT, ScheduledRequestInfo, SchedulerState], ] = None async def create_processes(self): @@ -473,9 +465,9 @@ def __init__( def requests_generator( self, - requests: Iterable[Iterable[tuple[RequestT, float]]] | None, - cycle_requests: Iterable[Iterable[tuple[RequestT, float]]] | None, - ) -> Generator[MultiTurnT[RequestT], None, None]: + requests: DatasetIterT[RequestT] | None, + cycle_requests: DatasetIterT[RequestT] | None, + ) -> Generator[list[RequestDataT[RequestT]], None, None]: """ Generate request-info pairs for worker processing with constraint evaluation. @@ -544,12 +536,12 @@ def received_callback( self, update: tuple[ ResponseT | None, - RequestT | MultiTurnRequestT, + RequestT, ScheduledRequestInfo, ], ) -> tuple[ ResponseT | None, - RequestT | MultiTurnRequestT, + RequestT, ScheduledRequestInfo, SchedulerState, ]: diff --git a/tests/unit/scheduler/test_objects.py b/tests/unit/scheduler/test_objects.py index df794ff8f..f76fcfd1d 100644 --- a/tests/unit/scheduler/test_objects.py +++ b/tests/unit/scheduler/test_objects.py @@ -7,13 +7,11 @@ import pytest from pydantic import ValidationError -from typing_extensions import TypeAliasType from guidellm.scheduler import ( BackendInterface, BackendT, MeasuredRequestTimings, - MultiTurnRequestT, RequestSchedulerTimings, RequestT, ResponseT, @@ -49,20 +47,6 @@ def test_backend_t(): assert BackendT.__constraints__ == () -def test_multi_turn_request_t(): - """Validate MultiTurnRequestT is a TypeAliasType for multi-turn requests.""" - assert isinstance(MultiTurnRequestT, TypeAliasType) - assert MultiTurnRequestT.__name__ == "MultiTurnRequestT" - - value = MultiTurnRequestT.__value__ - assert hasattr(value, "__origin__") - assert value.__origin__ is Union - - type_params = getattr(MultiTurnRequestT, "__type_params__", ()) - assert len(type_params) == 1 - assert type_params[0].__name__ == "RequestT" - - class TestBackendInterface: """Test the BackendInterface abstract base class.""" From eade3a2259d2dac1664b067a5b50d80433126b9a Mon Sep 17 00:00:00 2001 From: Samuel Monson Date: Tue, 10 Jun 2025 13:54:50 -0400 Subject: [PATCH 11/20] Add fixed prefix option to synthetic data Signed-off-by: Samuel Monson Add prefix before decode Signed-off-by: Samuel Monson Add unique single-token prefix to every request Co-authored-by: Mehul Co-authored-by: Samuel Monson Signed-off-by: Samuel Monson --- src/guidellm/dataset/__init__.py | 2 + src/guidellm/dataset/synthetic.py 
| 61 +++++++++++++++++++++++++++---- 2 files changed, 56 insertions(+), 7 deletions(-) diff --git a/src/guidellm/dataset/__init__.py b/src/guidellm/dataset/__init__.py index b90b72ff9..009ddf408 100644 --- a/src/guidellm/dataset/__init__.py +++ b/src/guidellm/dataset/__init__.py @@ -4,6 +4,7 @@ from .hf_datasets import HFDatasetsCreator from .in_memory import InMemoryDatasetCreator from .synthetic import ( + PrefixBucketConfig, SyntheticDatasetConfig, SyntheticDatasetCreator, SyntheticTextItemsGenerator, @@ -15,6 +16,7 @@ "FileDatasetCreator", "HFDatasetsCreator", "InMemoryDatasetCreator", + "PrefixBucketConfig", "SyntheticDatasetConfig", "SyntheticDatasetCreator", "SyntheticTextItemsGenerator", diff --git a/src/guidellm/dataset/synthetic.py b/src/guidellm/dataset/synthetic.py index 06972643b..dd93e9088 100644 --- a/src/guidellm/dataset/synthetic.py +++ b/src/guidellm/dataset/synthetic.py @@ -1,6 +1,6 @@ import json import random -from collections.abc import Iterable, Iterator +from collections.abc import Iterable, Iterator, Sequence from itertools import cycle from pathlib import Path from typing import Any, Optional, TypedDict, Union @@ -19,18 +19,36 @@ from guidellm.utils import EndlessTextCreator, IntegerRangeSampler, check_load_processor __all__ = [ + "PrefixBucketConfig", "SyntheticDatasetConfig", "SyntheticDatasetCreator", "SyntheticTextItemsGenerator", ] -class SyntheticDatasetConfig(BaseModel): +class PrefixBucketConfig(BaseModel): + bucket_weight: int = Field( + description="Weight of this bucket in the overall distribution.", + gt=0, + default=100, + ) + prefix_count: int = Field( + description="The number of unique prefixs to generate for this bucket.", + ge=1, + default=1, + ) prefix_tokens: int = Field( - description="The number of shared prefix tokens to prepend to each prompt.", + description="The number of prefix tokens per-prompt for this bucket.", ge=0, default=0, ) + + +class SyntheticDatasetConfig(BaseModel): + prefix_buckets: Optional[list[PrefixBucketConfig]] = Field( + description="Buckets for the prefix tokens distribution.", + default=None, + ) prompt_tokens: int = Field( description="The average number of text tokens generated for prompts.", gt=0, @@ -190,11 +208,9 @@ def __iter__( ) # ensure diff distribution from output tokens rand = random.Random(self.random_seed + 2) # noqa: S311 + shared_prefix_iter = iter(self._create_prefixes(rand)) unique_prefix_iter = cycle(self.processor.get_vocab().values()) - prefix_index = rand.randint(0, len(self.text_creator.words)) - prefix_tokens = self._create_prompt(self.config.prefix_tokens, prefix_index) - for _, turns in zip(range(self.config.samples), turns_sampler): row: SyntheticDatasetRow = { "prompt": [], @@ -207,6 +223,7 @@ def __iter__( output_tokens_sampler, ): start_index = rand.randint(0, len(self.text_creator.words)) + prefix_tokens = next(shared_prefix_iter, []) # Append the prefix tokens only for the first turn if i == 0: prompt_text = self.processor.decode( @@ -217,7 +234,7 @@ def __iter__( skip_special_tokens=True, ) row["prompt"].append(prompt_text) - row["prompt_tokens_count"].append(self.config.prefix_tokens + prompt_tokens) + row["prompt_tokens_count"].append(len(prefix_tokens) + prompt_tokens) row["output_tokens_count"].append(output_tokens) else: prompt_text = self.processor.decode( @@ -232,6 +249,36 @@ def __iter__( yield row + def _rand_start_index(self, rand: random.Random) -> int: + """Generate a random start index for text generation.""" + return rand.randint(0, len(self.text_creator.words) - 1) + + 
def _create_prefixes(self, rand: random.Random) -> Sequence[list[int]]: + """Create an iterator for shared prefix tokens.""" + buckets = self.config.prefix_buckets + + if not buckets: + return [] + + total_weight = sum(bucket.bucket_weight for bucket in buckets) + if total_weight <= 0: + raise ValueError("Total weight of prefix buckets must be greater than 0.") + + prompts = [] + for bucket in buckets: + for _ in range(bucket.prefix_count): + start_index = self._rand_start_index(rand) + prompt_tokens = self._create_prompt(bucket.prefix_tokens, start_index) + sample_percent = ( + bucket.bucket_weight / bucket.prefix_count / total_weight + ) + sample_count = sample_percent * self.config.samples + for _ in range(int(round(sample_count))): + prompts.append(prompt_tokens) + + rand.shuffle(prompts) + return prompts + def _create_prompt( self, prompt_tokens: int, start_index: int, unique_prefix: Optional[int] = None ) -> list[int]: From 5795c02f15966c7c845f6d552e1c9511dae7fc92 Mon Sep 17 00:00:00 2001 From: Samuel Monson Date: Tue, 19 Aug 2025 15:48:19 -0400 Subject: [PATCH 12/20] Update tests for new prefix patch and reduce the number of mocks Signed-off-by: Samuel Monson --- tests/unit/dataset/test_synthetic.py | 200 +++++++++------------------ 1 file changed, 66 insertions(+), 134 deletions(-) diff --git a/tests/unit/dataset/test_synthetic.py b/tests/unit/dataset/test_synthetic.py index e3110fa3c..b249ab300 100644 --- a/tests/unit/dataset/test_synthetic.py +++ b/tests/unit/dataset/test_synthetic.py @@ -11,6 +11,7 @@ import yaml from guidellm.dataset.synthetic import ( + PrefixBucketConfig, SyntheticDatasetConfig, SyntheticDatasetCreator, SyntheticTextItemsGenerator, @@ -29,8 +30,12 @@ def test_config_creation_with_all_params(self): ### WRITTEN BY AI ### """ + prefix_bucket = PrefixBucketConfig( + bucket_weight=100, prefix_count=1, prefix_tokens=5 + ) + config = SyntheticDatasetConfig( - prefix_tokens=5, + prefix_buckets=[prefix_bucket], prompt_tokens=100, prompt_tokens_stdev=10, prompt_tokens_min=50, @@ -43,7 +48,7 @@ def test_config_creation_with_all_params(self): source="custom_text.txt", ) - assert config.prefix_tokens == 5 + assert config.prefix_buckets[0].prefix_tokens == 5 # type: ignore [index] assert config.prompt_tokens == 100 assert config.prompt_tokens_stdev == 10 assert config.prompt_tokens_min == 50 @@ -67,7 +72,9 @@ def test_parse_json_string(self): "output_tokens": 25, "samples": 200, "source": "test.txt", - "prefix_tokens": 10, + "prefix_buckets": [ + {"bucket_weight": 100, "prefix_count": 1, "prefix_tokens": 10} + ], } ) @@ -77,7 +84,7 @@ def test_parse_json_string(self): assert config.output_tokens == 25 assert config.samples == 200 assert config.source == "test.txt" - assert config.prefix_tokens == 10 + assert config.prefix_buckets[0].prefix_tokens == 10 # type: ignore [index] @pytest.mark.regression def test_parse_key_value_pairs(self): @@ -85,7 +92,7 @@ def test_parse_key_value_pairs(self): ### WRITTEN BY AI ### """ - kv_str = "prompt_tokens=80,output_tokens=30,samples=300,source=data.txt,prefix_tokens=5" # noqa: E501 + kv_str = "prompt_tokens=80,output_tokens=30,samples=300,source=data.txt" config = SyntheticDatasetConfig.parse_str(kv_str) @@ -93,7 +100,7 @@ def test_parse_key_value_pairs(self): assert config.output_tokens == 30 assert config.samples == 300 assert config.source == "data.txt" - assert config.prefix_tokens == 5 + assert config.prefix_buckets is None @pytest.mark.sanity def test_parse_yaml_file(self): @@ -106,7 +113,9 @@ def test_parse_yaml_file(self): 
"output_tokens": 15, "samples": 100, "source": "yaml_test.txt", - "prefix_tokens": 3, + "prefix_buckets": [ + {"bucket_weight": 100, "prefix_count": 1, "prefix_tokens": 3} + ], } with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: @@ -120,7 +129,7 @@ def test_parse_yaml_file(self): assert config.output_tokens == 15 assert config.samples == 100 assert config.source == "yaml_test.txt" - assert config.prefix_tokens == 3 + assert config.prefix_buckets[0].prefix_tokens == 3 # type: ignore [index] finally: Path(yaml_path).unlink() @@ -134,7 +143,9 @@ def test_parse_config_file(self): "prompt_tokens": 90, "output_tokens": 35, "samples": 150, - "prefix_tokens": 2, + "prefix_buckets": [ + {"bucket_weight": 100, "prefix_count": 1, "prefix_tokens": 2} + ], } with tempfile.NamedTemporaryFile(mode="w", suffix=".config", delete=False) as f: @@ -147,7 +158,7 @@ def test_parse_config_file(self): assert config.prompt_tokens == 90 assert config.output_tokens == 35 assert config.samples == 150 - assert config.prefix_tokens == 2 + assert config.prefix_buckets[0].prefix_tokens == 2 # type: ignore [index] finally: Path(config_path).unlink() @@ -194,8 +205,9 @@ def test_validation_positive_values(self): with pytest.raises(ValueError): SyntheticDatasetConfig(prompt_tokens=20, output_tokens=10, samples=0) + # Test negative prefix tokens via PrefixBucketConfig validation with pytest.raises(ValueError): - SyntheticDatasetConfig(prompt_tokens=20, output_tokens=10, prefix_tokens=-1) + PrefixBucketConfig(prefix_tokens=-1) @pytest.mark.regression def test_validation_optional_positive_values(self): @@ -279,7 +291,7 @@ def mock_tokenizer(self): """ tokenizer = Mock() tokenizer.get_vocab.return_value = {f"token_{i}": i for i in range(1000)} - tokenizer.encode.side_effect = lambda text: [1, 2, 3] * (len(text) // 10 + 1) + tokenizer.encode.side_effect = lambda text: list(range(len(text.split()))) tokenizer.decode.side_effect = ( lambda tokens, skip_special_tokens=False: " ".join( f"token_{t}" for t in tokens[:5] @@ -287,6 +299,22 @@ def mock_tokenizer(self): ) return tokenizer + @pytest.fixture + def mock_integer_range_sampler(self): + """Fixture to provide a mocked IntegerRangeSampler. + + ### WRITTEN BY AI ### + """ + with patch("guidellm.dataset.synthetic.IntegerRangeSampler") as mock_sampler: + # Default side effect for basic iteration + def mock_sampler_side_effect(*args, **kwargs): + mock_instance = Mock() + mock_instance.__iter__ = Mock(return_value=iter([15, 15, 15, 15, 15])) + return mock_instance + + mock_sampler.side_effect = mock_sampler_side_effect + yield mock_sampler + @pytest.fixture def simple_config(self): """Fixture for simple configuration. @@ -306,8 +334,12 @@ def config_with_prefix(self): ### WRITTEN BY AI ### """ + prefix_bucket = PrefixBucketConfig( + bucket_weight=100, prefix_count=1, prefix_tokens=3 + ) + return SyntheticDatasetConfig( - prefix_tokens=3, + prefix_buckets=[prefix_bucket], prompt_tokens=15, output_tokens=10, samples=5, @@ -352,29 +384,16 @@ def test_generator_initialization( mock_text_creator.assert_called_once_with(data=simple_config.source) @pytest.mark.smoke - @patch("guidellm.dataset.synthetic.EndlessTextCreator") - @patch("guidellm.dataset.synthetic.IntegerRangeSampler") def test_basic_iteration( - self, mock_sampler, mock_text_creator, simple_config, mock_tokenizer + self, + mock_integer_range_sampler, + simple_config, + mock_tokenizer, ): """Test basic iteration functionality. 
### WRITTEN BY AI ### """ - # Setup mocks - mock_text_creator_instance = Mock() - mock_text_creator_instance.words = ["word1", "word2", "word3"] * 100 - mock_text_creator_instance.create_text.return_value = "sample text" - mock_text_creator.return_value = mock_text_creator_instance - - # Mock IntegerRangeSampler to return iterators - def mock_sampler_side_effect(*args, **kwargs): - mock_instance = Mock() - mock_instance.__iter__ = Mock(return_value=iter([15, 15, 15, 15, 15])) - return mock_instance - - mock_sampler.side_effect = mock_sampler_side_effect - generator = SyntheticTextItemsGenerator( simple_config, mock_tokenizer, random_seed=42 ) @@ -394,28 +413,19 @@ def mock_sampler_side_effect(*args, **kwargs): assert isinstance(item["output_tokens_count"], int) @pytest.mark.sanity - @patch("guidellm.dataset.synthetic.EndlessTextCreator") - def test_create_prompt_method( - self, mock_text_creator, simple_config, mock_tokenizer - ): + def test_create_prompt_method(self, simple_config, mock_tokenizer): """Test _create_prompt method. ### WRITTEN BY AI ### """ - mock_text_creator_instance = Mock() - mock_text_creator_instance.words = ["word"] * 100 - mock_text_creator_instance.create_text.return_value = "test text" - mock_text_creator.return_value = mock_text_creator_instance - - mock_tokenizer.encode.return_value = [1, 2, 3] - generator = SyntheticTextItemsGenerator( simple_config, mock_tokenizer, random_seed=42 ) # Test normal case result = generator._create_prompt(5, 0, 42) - assert result == [42, 1, 2, 3] + assert result[0] == 42 # Unique prefix token + assert len(result) == 5 # Test zero tokens result = generator._create_prompt(0, 0, 42) @@ -423,30 +433,14 @@ def test_create_prompt_method( # Test without unique prefix result = generator._create_prompt(3, 0) - assert result == [1, 2, 3] + assert len(result) == 3 @pytest.mark.regression - @patch("guidellm.dataset.synthetic.EndlessTextCreator") - def test_create_prompt_binary_search( - self, mock_text_creator, simple_config, mock_tokenizer - ): + def test_create_prompt_binary_search(self, simple_config, mock_tokenizer): """Test binary search logic in _create_prompt. ### WRITTEN BY AI ### """ - mock_text_creator_instance = Mock() - mock_text_creator_instance.words = ["word"] * 1000 - mock_text_creator_instance.create_text.side_effect = lambda start, length: ( - "text " * max(1, length // 4) - ).strip() - mock_text_creator.return_value = mock_text_creator_instance - - # Mock tokenizer to return different lengths based on input - def mock_encode(text): - return [1] * len(text.split()) - - mock_tokenizer.encode.side_effect = mock_encode - generator = SyntheticTextItemsGenerator( simple_config, mock_tokenizer, random_seed=42 ) @@ -456,25 +450,13 @@ def mock_encode(text): assert len(result) >= 4 # Should include prefix + some tokens @pytest.mark.sanity - @patch("guidellm.dataset.synthetic.EndlessTextCreator") - @patch("guidellm.dataset.synthetic.IntegerRangeSampler") def test_prefix_tokens_integration( - self, mock_sampler, mock_text_creator, config_with_prefix, mock_tokenizer + self, mock_integer_range_sampler, config_with_prefix, mock_tokenizer ): """Test integration with prefix tokens. 
### WRITTEN BY AI ### """ - # Setup mocks - mock_text_creator_instance = Mock() - mock_text_creator_instance.words = ["word"] * 100 - mock_text_creator_instance.create_text.return_value = "sample text" - mock_text_creator.return_value = mock_text_creator_instance - - mock_sampler_instance = Mock() - mock_sampler_instance.__iter__ = Mock(return_value=iter([15, 15, 15, 15, 15])) - mock_sampler.return_value = mock_sampler_instance - generator = SyntheticTextItemsGenerator( config_with_prefix, mock_tokenizer, random_seed=42 ) @@ -483,40 +465,19 @@ def test_prefix_tokens_integration( # Verify prompt_tokens_count includes prefix for item in items: - assert item["prompt_tokens_count"] == config_with_prefix.prefix_tokens + 15 + assert ( + item["prompt_tokens_count"] + == config_with_prefix.prefix_buckets[0].prefix_tokens + 15 + ) @pytest.mark.regression - @patch("guidellm.dataset.synthetic.EndlessTextCreator") - @patch("guidellm.dataset.synthetic.IntegerRangeSampler") def test_random_seeding_consistency( - self, mock_sampler, mock_text_creator, simple_config, mock_tokenizer + self, mock_integer_range_sampler, simple_config, mock_tokenizer ): """Test that same seed produces consistent results. ### WRITTEN BY AI ### """ - # Setup mocks - mock_text_creator_instance = Mock() - mock_text_creator_instance.words = ["word"] * 100 - mock_text_creator_instance.create_text.return_value = "sample text" - mock_text_creator.return_value = mock_text_creator_instance - - # Create consistent mock sampler behavior - call_count = 0 - - def mock_sampler_side_effect(*args, **kwargs): - nonlocal call_count - mock_instance = Mock() - # Return same sequence for both prompt and output tokens - if call_count % 2 == 0: # prompt tokens - mock_instance.__iter__ = Mock(return_value=iter([15, 16, 17, 18, 19])) - else: # output tokens - mock_instance.__iter__ = Mock(return_value=iter([10, 11, 12, 13, 14])) - call_count += 1 - return mock_instance - - mock_sampler.side_effect = mock_sampler_side_effect - # Create two generators with same seed generator1 = SyntheticTextItemsGenerator( simple_config, mock_tokenizer, random_seed=42 @@ -528,7 +489,7 @@ def mock_sampler_side_effect(*args, **kwargs): items1 = list(generator1) items2 = list(generator2) - # Results should be identical with same seed + # With same seed and deterministic mocks, results should be identical assert len(items1) == len(items2) for item1, item2 in zip(items1, items2): assert item1["prompt"] == item2["prompt"] @@ -536,34 +497,13 @@ def mock_sampler_side_effect(*args, **kwargs): assert item1["output_tokens_count"] == item2["output_tokens_count"] @pytest.mark.regression - @patch("guidellm.dataset.synthetic.EndlessTextCreator") - @patch("guidellm.dataset.synthetic.IntegerRangeSampler") def test_variance_configuration( - self, mock_sampler, mock_text_creator, complex_config, mock_tokenizer + self, mock_integer_range_sampler, complex_config, mock_tokenizer ): """Test that variance configuration is properly used. 
### WRITTEN BY AI ### """ - # Setup mocks - mock_text_creator_instance = Mock() - mock_text_creator_instance.words = ["word"] * 100 - mock_text_creator_instance.create_text.return_value = "sample text" - mock_text_creator.return_value = mock_text_creator_instance - - # Fix tokenizer mock to handle the create_text return properly - mock_tokenizer.encode.side_effect = ( - lambda text: [1, 2, 3] if isinstance(text, str) else [1, 2, 3] - ) - - # Setup mock sampler to track calls - def mock_sampler_side_effect(*args, **kwargs): - mock_instance = Mock() - mock_instance.__iter__ = Mock(return_value=iter([20, 18, 22, 19, 21] * 2)) - return mock_instance - - mock_sampler.side_effect = mock_sampler_side_effect - generator = SyntheticTextItemsGenerator( complex_config, mock_tokenizer, random_seed=42 ) @@ -573,10 +513,10 @@ def mock_sampler_side_effect(*args, **kwargs): next(generator_iter) # Verify that IntegerRangeSampler is called with correct parameters - assert mock_sampler.call_count == 2 + assert mock_integer_range_sampler.call_count == 2 # Check prompt tokens sampler call - prompt_call = mock_sampler.call_args_list[0] + prompt_call = mock_integer_range_sampler.call_args_list[0] assert prompt_call[1]["average"] == complex_config.prompt_tokens assert prompt_call[1]["variance"] == complex_config.prompt_tokens_stdev assert prompt_call[1]["min_value"] == complex_config.prompt_tokens_min @@ -584,7 +524,7 @@ def mock_sampler_side_effect(*args, **kwargs): assert prompt_call[1]["random_seed"] == 42 # Check output tokens sampler call - output_call = mock_sampler.call_args_list[1] + output_call = mock_integer_range_sampler.call_args_list[1] assert output_call[1]["average"] == complex_config.output_tokens assert output_call[1]["variance"] == complex_config.output_tokens_stdev assert output_call[1]["min_value"] == complex_config.output_tokens_min @@ -592,19 +532,11 @@ def mock_sampler_side_effect(*args, **kwargs): assert output_call[1]["random_seed"] == 43 # 42 + 1 @pytest.mark.regression - @patch("guidellm.dataset.synthetic.EndlessTextCreator") - def test_unique_prefix_generation( - self, mock_text_creator, simple_config, mock_tokenizer - ): + def test_unique_prefix_generation(self, simple_config, mock_tokenizer): """Test that unique prefixes are generated for each request. ### WRITTEN BY AI ### """ - mock_text_creator_instance = Mock() - mock_text_creator_instance.words = ["word"] * 100 - mock_text_creator_instance.create_text.return_value = "sample text" - mock_text_creator.return_value = mock_text_creator_instance - # Mock the cycle to return predictable values with patch("guidellm.dataset.synthetic.cycle") as mock_cycle: mock_cycle.return_value = iter([100, 101, 102, 103, 104]) From 66b531145638392f06935498675396ced4e99604 Mon Sep 17 00:00:00 2001 From: Samuel Monson Date: Tue, 19 Aug 2025 17:28:48 -0400 Subject: [PATCH 13/20] Add more prefix bucket testcases Signed-off-by: Samuel Monson --- tests/unit/dataset/test_synthetic.py | 220 ++++++++++++++++++++++++++- 1 file changed, 218 insertions(+), 2 deletions(-) diff --git a/tests/unit/dataset/test_synthetic.py b/tests/unit/dataset/test_synthetic.py index b249ab300..080fcbfbc 100644 --- a/tests/unit/dataset/test_synthetic.py +++ b/tests/unit/dataset/test_synthetic.py @@ -18,6 +18,76 @@ ) +class TestPrefixBucketConfig: + """Test cases for PrefixBucketConfig class. + + ### WRITTEN BY AI ### + """ + + @pytest.mark.smoke + def test_creation_with_valid_params(self): + """Test creating PrefixBucketConfig with valid parameters. 
+ + ### WRITTEN BY AI ### + """ + config = PrefixBucketConfig(bucket_weight=100, prefix_count=1, prefix_tokens=5) + + assert config.bucket_weight == 100 + assert config.prefix_count == 1 + assert config.prefix_tokens == 5 + + @pytest.mark.sanity + def test_creation_with_negative_values(self): + """Test creating PrefixBucketConfig with negative values raises ValueError. + + ### WRITTEN BY AI ### + """ + with pytest.raises(ValueError): + PrefixBucketConfig(bucket_weight=-10, prefix_count=1, prefix_tokens=5) + + with pytest.raises(ValueError): + PrefixBucketConfig(bucket_weight=100, prefix_count=-1, prefix_tokens=5) + + with pytest.raises(ValueError): + PrefixBucketConfig(bucket_weight=100, prefix_count=1, prefix_tokens=-5) + + @pytest.mark.regression + def test_prefix_bucket_zero_weight_error(self): + """Test that zero total weight raises an error. + + ### WRITTEN BY AI ### + """ + # Test validation error for creating PrefixBucketConfig with weight=0 + with pytest.raises(ValueError): + PrefixBucketConfig(bucket_weight=0, prefix_count=1, prefix_tokens=2) + + @pytest.mark.sanity + def test_prefix_bucket_config_validation(self): + """Test PrefixBucketConfig validation. + + ### WRITTEN BY AI ### + """ + # Test valid config + valid_config = PrefixBucketConfig( + bucket_weight=50, prefix_count=2, prefix_tokens=3 + ) + assert valid_config.bucket_weight == 50 + assert valid_config.prefix_count == 2 + assert valid_config.prefix_tokens == 3 + + # Test invalid bucket_weight + with pytest.raises(ValueError): + PrefixBucketConfig(bucket_weight=0, prefix_count=1, prefix_tokens=2) + + # Test invalid prefix_count + with pytest.raises(ValueError): + PrefixBucketConfig(bucket_weight=100, prefix_count=0, prefix_tokens=2) + + # Test invalid prefix_tokens + with pytest.raises(ValueError): + PrefixBucketConfig(bucket_weight=100, prefix_count=1, prefix_tokens=-1) + + class TestSyntheticDatasetConfig: """Test cases for SyntheticDatasetConfig class. @@ -306,10 +376,11 @@ def mock_integer_range_sampler(self): ### WRITTEN BY AI ### """ with patch("guidellm.dataset.synthetic.IntegerRangeSampler") as mock_sampler: - # Default side effect for basic iteration + # Side effect for basic iteration with enough values for larger tests def mock_sampler_side_effect(*args, **kwargs): mock_instance = Mock() - mock_instance.__iter__ = Mock(return_value=iter([15, 15, 15, 15, 15])) + # Provide enough values for tests (up to 20 items) + mock_instance.__iter__ = Mock(return_value=iter([15] * 20)) return mock_instance mock_sampler.side_effect = mock_sampler_side_effect @@ -346,6 +417,45 @@ def config_with_prefix(self): source="The quick brown fox jumps over the lazy dog.", ) + @pytest.fixture + def config_with_multiple_prefix_buckets(self): + """Fixture for configuration with multiple prefix buckets. + + ### WRITTEN BY AI ### + """ + prefix_bucket1 = PrefixBucketConfig( + bucket_weight=60, prefix_count=1, prefix_tokens=2 + ) + prefix_bucket2 = PrefixBucketConfig( + bucket_weight=40, prefix_count=1, prefix_tokens=4 + ) + + return SyntheticDatasetConfig( + prefix_buckets=[prefix_bucket1, prefix_bucket2], + prompt_tokens=10, + output_tokens=5, + samples=10, + source="The quick brown fox jumps over the lazy dog.", + ) + + @pytest.fixture + def config_with_multiple_prefix_counts(self): + """Fixture for configuration with prefix_count > 1. 
+ + ### WRITTEN BY AI ### + """ + prefix_bucket = PrefixBucketConfig( + bucket_weight=100, prefix_count=3, prefix_tokens=2 + ) + + return SyntheticDatasetConfig( + prefix_buckets=[prefix_bucket], + prompt_tokens=8, + output_tokens=4, + samples=6, + source="The quick brown fox jumps over the lazy dog.", + ) + @pytest.fixture def complex_config(self): """Fixture for complex configuration with variance. @@ -552,6 +662,112 @@ def test_unique_prefix_generation(self, simple_config, mock_tokenizer): # Verify cycle was called with vocab values mock_cycle.assert_called_once() + @pytest.mark.regression + def test_multiple_prefix_buckets_distribution( + self, + mock_integer_range_sampler, + config_with_multiple_prefix_buckets, + mock_tokenizer, + ): + """Test distribution across multiple prefix buckets with different weights. + + ### WRITTEN BY AI ### + """ + generator = SyntheticTextItemsGenerator( + config_with_multiple_prefix_buckets, mock_tokenizer, random_seed=42 + ) + + items = list(generator) + + # Verify we get the expected number of items + assert len(items) == config_with_multiple_prefix_buckets.samples + + # Verify that prefix tokens are added to prompt_tokens_count + # Since we have buckets with 2 and 4 prefix tokens, and the mock returns 15 + # prompt tokens, we should see prompt_tokens_count of either 17 or 19 + prefix_counts = [item["prompt_tokens_count"] for item in items] + assert all(count in [17, 19] for count in prefix_counts) + + # Calculate expected distribution based on weights + # Bucket 1: weight=60, prefix_count=1, prefix_tokens=2 + # Bucket 2: weight=40, prefix_count=1, prefix_tokens=4 + # Total weight = 100, samples = 10 + # Bucket 1: (60/1/100) * 10 = 6 samples with 17 tokens (2 prefix + 15 prompt) + # Bucket 2: (40/1/100) * 10 = 4 samples with 19 tokens (4 prefix + 15 prompt) + count_17 = prefix_counts.count(17) # 2 prefix tokens + count_19 = prefix_counts.count(19) # 4 prefix tokens + assert count_17 == 6 + assert count_19 == 4 + + @pytest.mark.regression + def test_multiple_prefix_counts( + self, + mock_integer_range_sampler, + config_with_multiple_prefix_counts, + mock_tokenizer, + ): + """Test prefix buckets with prefix_count > 1. + + ### WRITTEN BY AI ### + """ + generator = SyntheticTextItemsGenerator( + config_with_multiple_prefix_counts, mock_tokenizer, random_seed=42 + ) + + items = list(generator) + + # Verify we get the expected number of items + assert len(items) == config_with_multiple_prefix_counts.samples + + # All items should have 2 prefix tokens + 15 prompt tokens = 17 total + for item in items: + assert item["prompt_tokens_count"] == 17 + + @pytest.mark.sanity + def test_prefix_buckets_create_prefixes_method( + self, config_with_multiple_prefix_buckets, mock_tokenizer + ): + """Test the _create_prefixes method directly. 
+ + ### WRITTEN BY AI ### + """ + generator = SyntheticTextItemsGenerator( + config_with_multiple_prefix_buckets, mock_tokenizer, random_seed=42 + ) + + # Test _create_prefixes method + rand = Mock() + rand.randint = Mock(return_value=0) + prefixes = generator._create_prefixes(rand) + + # Should return a sequence of prefix token lists + assert isinstance(prefixes, list) + assert len(prefixes) == 10 + + # Each prefix should be a list of integers + for prefix in prefixes: + assert isinstance(prefix, list) + assert all(isinstance(token, int) for token in prefix) + + @pytest.mark.regression + def test_empty_prefix_buckets( + self, mock_integer_range_sampler, simple_config, mock_tokenizer + ): + """Test behavior when prefix_buckets is None or empty. + + ### WRITTEN BY AI ### + """ + # Test with None prefix_buckets (simple_config has None) + generator = SyntheticTextItemsGenerator( + simple_config, mock_tokenizer, random_seed=42 + ) + + items = list(generator) + + # All items should have exactly the prompt tokens (no prefix) + for item in items: + assert item["prompt_tokens_count"] == 15 # Mock returns 15 + class TestSyntheticDatasetCreator: """Test cases for SyntheticDatasetCreator class. From 5ddb73cddbb9936b6a18112f79724172a5e38d26 Mon Sep 17 00:00:00 2001 From: Samuel Monson Date: Tue, 30 Sep 2025 13:51:49 -0400 Subject: [PATCH 14/20] Append prefix tokens only to first turn Signed-off-by: Samuel Monson --- src/guidellm/dataset/synthetic.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/guidellm/dataset/synthetic.py b/src/guidellm/dataset/synthetic.py index dd93e9088..42ead9d0f 100644 --- a/src/guidellm/dataset/synthetic.py +++ b/src/guidellm/dataset/synthetic.py @@ -223,9 +223,9 @@ def __iter__( output_tokens_sampler, ): start_index = rand.randint(0, len(self.text_creator.words)) - prefix_tokens = next(shared_prefix_iter, []) # Append the prefix tokens only for the first turn if i == 0: + prefix_tokens = next(shared_prefix_iter, []) prompt_text = self.processor.decode( prefix_tokens + self._create_prompt( @@ -234,7 +234,9 @@ def __iter__( skip_special_tokens=True, ) row["prompt"].append(prompt_text) - row["prompt_tokens_count"].append(len(prefix_tokens) + prompt_tokens) + row["prompt_tokens_count"].append( + len(prefix_tokens) + prompt_tokens + ) row["output_tokens_count"].append(output_tokens) else: prompt_text = self.processor.decode( From 3902a287cbbacce100c6309dd2dd563f3dbfca98 Mon Sep 17 00:00:00 2001 From: Samuel Monson Date: Tue, 7 Oct 2025 11:26:22 -0400 Subject: [PATCH 15/20] Backport new prefix_tokens bucketting algorithm --- src/guidellm/dataset/synthetic.py | 78 +++++++++++++++++++++---------- 1 file changed, 54 insertions(+), 24 deletions(-) diff --git a/src/guidellm/dataset/synthetic.py b/src/guidellm/dataset/synthetic.py index 42ead9d0f..7f47db10a 100644 --- a/src/guidellm/dataset/synthetic.py +++ b/src/guidellm/dataset/synthetic.py @@ -1,6 +1,7 @@ import json +import math import random -from collections.abc import Iterable, Iterator, Sequence +from collections.abc import Iterable, Iterator from itertools import cycle from pathlib import Path from typing import Any, Optional, TypedDict, Union @@ -12,8 +13,9 @@ IterableDataset, IterableDatasetDict, ) -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field, model_validator from transformers import PreTrainedTokenizerBase # type: ignore[import] +from typing_extensions import Self from guidellm.dataset.creator import ColumnInputTypes, DatasetCreator from guidellm.utils 
import EndlessTextCreator, IntegerRangeSampler, check_load_processor @@ -45,6 +47,10 @@ class PrefixBucketConfig(BaseModel): class SyntheticDatasetConfig(BaseModel): + model_config = ConfigDict( + extra="allow", + ) + prefix_buckets: Optional[list[PrefixBucketConfig]] = Field( description="Buckets for the prefix tokens distribution.", default=None, @@ -117,6 +123,26 @@ class SyntheticDatasetConfig(BaseModel): default="data:prideandprejudice.txt.gz", ) + @model_validator(mode="after") + def check_prefix_options(self) -> Self: + prefix_count = self.__pydantic_extra__.get("prefix_count", None) # type: ignore[attr-defined] + prefix_tokens = self.__pydantic_extra__.get("prefix_count", None) # type: ignore[attr-defined] + if prefix_count is not None or prefix_tokens is not None: + if self.prefix_buckets: + raise ValueError( + "prefix_buckets is mutually exclusive" + " with prefix_count and prefix_tokens" + ) + + self.prefix_buckets = [ + PrefixBucketConfig( + prefix_count=prefix_count or 1, + prefix_tokens=prefix_tokens or 0, + ) + ] + + return self + @staticmethod def parse_str(data: Union[str, Path]) -> "SyntheticDatasetConfig": if ( @@ -207,8 +233,8 @@ def __iter__( random_seed=self.random_seed + 7, # ensure diff dist ) # ensure diff distribution from output tokens - rand = random.Random(self.random_seed + 2) # noqa: S311 - shared_prefix_iter = iter(self._create_prefixes(rand)) + rand = random.Random(self.random_seed + 3) # noqa: S311 + shared_prefix_iter = self._create_prefix_iter(rand) unique_prefix_iter = cycle(self.processor.get_vocab().values()) for _, turns in zip(range(self.config.samples), turns_sampler): @@ -255,31 +281,35 @@ def _rand_start_index(self, rand: random.Random) -> int: """Generate a random start index for text generation.""" return rand.randint(0, len(self.text_creator.words) - 1) - def _create_prefixes(self, rand: random.Random) -> Sequence[list[int]]: - """Create an iterator for shared prefix tokens.""" - buckets = self.config.prefix_buckets - - if not buckets: - return [] - - total_weight = sum(bucket.bucket_weight for bucket in buckets) - if total_weight <= 0: - raise ValueError("Total weight of prefix buckets must be greater than 0.") + def _create_prefix_iter(self, rand: random.Random) -> Iterator[list[int]]: + if not self.config.prefix_buckets: + while True: + yield [] - prompts = [] - for bucket in buckets: + # Increase weights to ensure an integer number of samples per per-prefix + least_common_prefix_count = math.lcm( + *(bucket.prefix_count for bucket in self.config.prefix_buckets) + ) + unnorm_weights = [ + least_common_prefix_count * bucket.bucket_weight // bucket.prefix_count + for bucket in self.config.prefix_buckets + ] + # Use GCD to reduce the weights to smallest integer ratio + common_divisor = math.gcd(*unnorm_weights) + + # Create prefix list maintaining the correct distribution + prefixes = [] + for bucket, weight in zip(self.config.prefix_buckets, unnorm_weights): + bucket_prefixes = [] for _ in range(bucket.prefix_count): start_index = self._rand_start_index(rand) prompt_tokens = self._create_prompt(bucket.prefix_tokens, start_index) - sample_percent = ( - bucket.bucket_weight / bucket.prefix_count / total_weight - ) - sample_count = sample_percent * self.config.samples - for _ in range(int(round(sample_count))): - prompts.append(prompt_tokens) + bucket_prefixes.append(prompt_tokens) + sample_count = weight // common_divisor + prefixes.extend(bucket_prefixes * sample_count) - rand.shuffle(prompts) - return prompts + while True: + yield 
rand.choice(prefixes) def _create_prompt( self, prompt_tokens: int, start_index: int, unique_prefix: Optional[int] = None From 98a0e372331f0c91e2ff77fa89f5cd54a6e88b7d Mon Sep 17 00:00:00 2001 From: Samuel Monson Date: Tue, 7 Oct 2025 11:46:45 -0400 Subject: [PATCH 16/20] Fix incorrect field name in validator --- src/guidellm/dataset/synthetic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/guidellm/dataset/synthetic.py b/src/guidellm/dataset/synthetic.py index 7f47db10a..6d837de1c 100644 --- a/src/guidellm/dataset/synthetic.py +++ b/src/guidellm/dataset/synthetic.py @@ -126,7 +126,7 @@ class SyntheticDatasetConfig(BaseModel): @model_validator(mode="after") def check_prefix_options(self) -> Self: prefix_count = self.__pydantic_extra__.get("prefix_count", None) # type: ignore[attr-defined] - prefix_tokens = self.__pydantic_extra__.get("prefix_count", None) # type: ignore[attr-defined] + prefix_tokens = self.__pydantic_extra__.get("prefix_tokens", None) # type: ignore[attr-defined] if prefix_count is not None or prefix_tokens is not None: if self.prefix_buckets: raise ValueError( From 213d8014a9290b1230016a5fe103913b51a524cd Mon Sep 17 00:00:00 2001 From: Samuel Monson Date: Mon, 13 Oct 2025 11:36:26 -0400 Subject: [PATCH 17/20] Map history only from previous request --- src/guidellm/backends/openai.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/guidellm/backends/openai.py b/src/guidellm/backends/openai.py index acce5f88d..40b1262bb 100644 --- a/src/guidellm/backends/openai.py +++ b/src/guidellm/backends/openai.py @@ -16,7 +16,6 @@ import json import time from collections.abc import AsyncIterator -from itertools import chain from pathlib import Path from typing import Any, ClassVar, Optional, Union @@ -508,11 +507,13 @@ def _apply_history( Apply conversation history to the current request. 
""" - def turn_to_text(turn: tuple[GenerationRequest, GenerationResponse]) -> str: - req, res = turn - return f"{req.content}{res.value}" + if len(history) > 0: + last_request = history[-1][0] + last_response = history[-1][1] + request.content = ( + f"{last_request.content}{last_response.value} {request.content}" + ) - request.content = "".join(chain(map(turn_to_text, history), (request.content,))) return request def _build_headers( From 79caf54a77323319e6d50599172c80c82bcb2f79 Mon Sep 17 00:00:00 2001 From: Samuel Monson Date: Mon, 20 Oct 2025 12:55:20 -0400 Subject: [PATCH 18/20] Only track the last response in multi-turn --- src/guidellm/scheduler/worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/guidellm/scheduler/worker.py b/src/guidellm/scheduler/worker.py index 4c5903fb2..b54398212 100644 --- a/src/guidellm/scheduler/worker.py +++ b/src/guidellm/scheduler/worker.py @@ -376,7 +376,7 @@ async def _process_next_request(self) -> ProcessRequestT[RequestT, ResponseT]: self._send_update("completed", response, request, request_info) # Record Turn - history.append((request, response)) + history = [(request, response)] response = request = request_info = None except asyncio.CancelledError: From 23f818668db8a321741228c89e7a66f2698e67fd Mon Sep 17 00:00:00 2001 From: Samuel Monson Date: Mon, 20 Oct 2025 16:20:28 -0400 Subject: [PATCH 19/20] Clear conversation on premature exit --- src/guidellm/scheduler/worker.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/guidellm/scheduler/worker.py b/src/guidellm/scheduler/worker.py index b54398212..1ea7cfb39 100644 --- a/src/guidellm/scheduler/worker.py +++ b/src/guidellm/scheduler/worker.py @@ -399,6 +399,8 @@ async def _process_next_request(self) -> ProcessRequestT[RequestT, ResponseT]: request_info.error = "Request was cancelled" request_info.scheduler_timings.resolve_end = time.time() self._send_update("cancelled", response, request, request_info) + # Clear conversation on premature exit + conversation = [] return history, conversation, aug From e6c7e559811589bacdf7a9cd857740241b20edf2 Mon Sep 17 00:00:00 2001 From: Samuel Monson Date: Fri, 10 Oct 2025 09:36:09 -0400 Subject: [PATCH 20/20] Configurable max_tokens/max_completion_tokens key (#399) Makes the `max_tokens` request key configurable through an environment variable per endpoint type. Defaults to `max_tokens` for legacy `completions` and `max_completion_tokens` for `chat/completions` - Add the `GUIDELLM__OPENAI__MAX_OUTPUT_KEY` config option which is a dict mapping from route name -> output tokens key. Default is `{"text_completions": "max_tokens", "chat_completions": "max_completion_tokens"}` - - Closes #395 - Closes #269 - Related #210 --- - [x] "I certify that all code in this PR is my own, except as noted below." 
- [ ] Includes AI-assisted code completion
- [ ] Includes code generated by an AI application
- [ ] Includes AI-generated tests (NOTE: AI written tests should have a docstring that includes `## WRITTEN BY AI ##`)

---------

Signed-off-by: Tyler Michael Smith
Signed-off-by: Samuel Monson
Co-authored-by: Tyler Michael Smith
---
 src/guidellm/backends/openai.py | 9 ++++-----
 src/guidellm/settings.py        | 4 ++++
 2 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/src/guidellm/backends/openai.py b/src/guidellm/backends/openai.py
index 40b1262bb..97cfc9a4a 100644
--- a/src/guidellm/backends/openai.py
+++ b/src/guidellm/backends/openai.py
@@ -30,6 +30,7 @@
     GenerationResponse,
 )
 from guidellm.scheduler import HistoryT, ScheduledRequestInfo
+from guidellm.settings import settings
 
 __all__ = ["OpenAIHTTPBackend", "UsageStats"]
 
@@ -628,12 +629,10 @@ def _get_body(
         # Handle token limits
         max_tokens = max_output_tokens or self.max_output_tokens
         if max_tokens is not None:
-            body.update(
-                {
-                    "max_tokens": max_tokens,
-                    "max_completion_tokens": max_tokens,
-                }
+            max_output_key = settings.openai.max_output_key.get(
+                endpoint_type, "max_tokens"
             )
+            body[max_output_key] = max_tokens
             # Set stop conditions only for request-level limits
             if max_output_tokens:
                 body.update({"stop": None, "ignore_eos": True})
diff --git a/src/guidellm/settings.py b/src/guidellm/settings.py
index 20d9ff96f..ab6bd96f3 100644
--- a/src/guidellm/settings.py
+++ b/src/guidellm/settings.py
@@ -89,6 +89,10 @@ class OpenAISettings(BaseModel):
     base_url: str = "http://localhost:8000"
     max_output_tokens: int = 16384
     verify: bool = True
+    max_output_key: dict[str, str] = {
+        "text_completions": "max_tokens",
+        "chat_completions": "max_completion_tokens",
+    }
 
 
 class ReportGenerationSettings(BaseModel):
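
Note (illustrative, not part of the patch): the commit message for the final patch introduces the `GUIDELLM__OPENAI__MAX_OUTPUT_KEY` setting as a route-name -> body-key mapping. The sketch below is a minimal, self-contained Python illustration of that key-selection behaviour; the `pick_output_key` helper and the override example are assumptions made for illustration only (in guidellm the lookup happens inside `_get_body` via `settings.openai.max_output_key`), and the environment-variable override relies on pydantic-settings parsing dict fields from JSON strings.

```python
from typing import Optional

# Mirrors the default mapping added to OpenAISettings in this patch series.
DEFAULT_MAX_OUTPUT_KEY = {
    "text_completions": "max_tokens",
    "chat_completions": "max_completion_tokens",
}


def pick_output_key(endpoint_type: str, mapping: Optional[dict[str, str]] = None) -> str:
    """Return the request-body key used to cap output tokens for a route."""
    mapping = mapping or DEFAULT_MAX_OUTPUT_KEY
    # Unknown endpoint types fall back to the legacy "max_tokens" key.
    return mapping.get(endpoint_type, "max_tokens")


# Default behaviour: legacy completions keep "max_tokens",
# chat completions switch to "max_completion_tokens".
assert pick_output_key("text_completions") == "max_tokens"
assert pick_output_key("chat_completions") == "max_completion_tokens"

# A hypothetical override (e.g. supplied via GUIDELLM__OPENAI__MAX_OUTPUT_KEY
# as a JSON string) can force a single key for every route.
override = {"text_completions": "max_tokens", "chat_completions": "max_tokens"}
body = {"model": "my-model", pick_output_key("chat_completions", override): 256}
assert body["max_tokens"] == 256
```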