diff --git a/workers/openai/worker.py b/workers/openai/worker.py index 6cf17f0..995fb3d 100644 --- a/workers/openai/worker.py +++ b/workers/openai/worker.py @@ -28,6 +28,12 @@ MODEL_INFO_LOG_MSGS = [ nltk.download("words") WORD_LIST = nltk.corpus.words.words() +def request_parser(request): + data = request + if request.get("input") is not None: + data = request.get("input") + return data + def completions_benchmark_generator() -> dict: prompt = " ".join(random.choices(WORD_LIST, k=int(250))) @@ -55,6 +61,7 @@ worker_config = WorkerConfig( workload_calculator= lambda data: data.get("max_tokens", 0), allow_parallel_requests=True, max_queue_time=60.0, + request_parser=request_parser, benchmark_config=BenchmarkConfig( generator=completions_benchmark_generator, concurrency=100, @@ -66,6 +73,7 @@ worker_config = WorkerConfig( workload_calculator= lambda data: data.get("max_tokens", 0), allow_parallel_requests=True, max_queue_time=60.0, + request_parser=request_parser ) ], log_action_config=LogActionConfig(