diff --git a/requirements.txt b/requirements.txt index 04b5600..08843aa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1 @@ vastai-sdk -nltk diff --git a/workers/openai/worker.py b/workers/openai/worker.py index 241a59a..4254f33 100644 --- a/workers/openai/worker.py +++ b/workers/openai/worker.py @@ -1,4 +1,3 @@ -import nltk import random import os @@ -21,9 +20,6 @@ MODEL_ERROR_LOG_MSGS = [ MODEL_INFO_LOG_MSGS = [ ] -nltk.download("words") -WORD_LIST = nltk.corpus.words.words() - def request_parser(request): data = request if request.get("input") is not None: @@ -32,6 +28,14 @@ def request_parser(request): def completions_benchmark_generator() -> dict: + # extract words from the python source code of the worker to create a list of words for generating prompts + + WORD_LIST = [] + + with open(__file__, 'r') as f: + for line in f: + WORD_LIST.extend(line.strip().split()) + prompt = " ".join(random.choices(WORD_LIST, k=int(250))) model = os.environ.get("MODEL_NAME") if not model: