LLMs

Wrapper for interacting with external and hosted large language models (LLMs).

Classes:

  • CachedLLM

    A caching wrapper for LLM implementations that stores and retrieves previous LLM responses.

  • LLM

    An abstract base class representing a Large Language Model (LLM).

  • LLMEvent

    An event class representing either a chunk of LLM output or the final LLM output.

  • LLMStream

    A wrapper class for LLM generators that provides convenient iteration and output extraction.

  • LiteLLM

    A LiteLLM implementation of the LLM interface.

  • MockLLM

    A mock LLM implementation for testing purposes.

  • ReplayLLM

    Specialized LLM class that replays previously recorded LLM interactions.

  • TrainableLLM

    Class for interacting with trainable language models through OpenAI-compatible API endpoints.

Functions:

  • closest_prompt

    Finds the closest matching prompt from a list of known prompts based on a Levenshtein similarity ratio.

  • trainable_llm_make_training_text

    Generates training text for LLM fine-tuning by combining the prompt and output using the tokenizer's chat template.

CachedLLM

Bases: LLM

A caching wrapper for LLM implementations that stores and retrieves previous LLM responses.

This class implements caching functionality for LLM responses to avoid redundant API calls and enable replay of previous interactions. It supports both file-based caching and SQLite-based replay functionality for testing purposes.

Attributes:

  • use_cache (bool) –

    Flag to enable/disable caching functionality. Defaults to False.

  • stream (bool) –

    Flag to enable/disable streaming responses. Defaults to False.

  • _cache (dict) –

    Internal cache storage mapping prompt keys to LLM responses.

The cache can be initialized in two modes:

  1. SQLite replay mode: used for testing, enforces cache hits only
  2. File-based cache mode: stores responses in a jsonl file for persistence

Cache keys are generated based on the prompt content, excluding the prompt ID. During testing (replay mode), exact text matching is used instead of hashing.
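
For illustration, a minimal sketch of how such a key can be derived, mirroring the get_prompt_key and _key logic shown in the source below (the replay_mode flag stands in for the module's _REPLAY_SQLITE check):

import hashlib
import json

def prompt_cache_key(prompt: Prompt, replay_mode: bool = False) -> str:
    # Serialize the prompt deterministically, excluding the volatile "id" field.
    prompt_text = json.dumps(prompt.model_dump(exclude={"id"}), ensure_ascii=False, sort_keys=True)
    if replay_mode:
        # Replay mode keys on the exact text so mismatches are easy to diff.
        return prompt_text
    # Otherwise an MD5 digest keeps keys short and stable across runs.
    return hashlib.md5(prompt_text.encode("utf-8")).hexdigest()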

Methods:

  • generate

    Generate a response stream from the language model based on the given prompt.

  • reindex_log

    Reindex the log data into cache.

Source code in tapeagents/llms.py
class CachedLLM(LLM):
    """A caching wrapper for LLM implementations that stores and retrieves previous LLM responses.

    This class implements caching functionality for LLM responses to avoid redundant API calls
    and enable replay of previous interactions. It supports both file-based caching and SQLite-based
    replay functionality for testing purposes.

    Attributes:
        use_cache (bool): Flag to enable/disable caching functionality. Defaults to False.
        stream (bool): Flag to enable/disable streaming responses. Defaults to False.
        _cache (dict): Internal cache storage mapping prompt keys to LLM responses.

    The cache can be initialized in two modes:
    1. SQLite replay mode: Used for testing, enforces cache hits only
    2. File-based cache mode: Stores responses in a jsonl file for persistence

    Cache keys are generated based on the prompt content, excluding the prompt ID.
    During testing (replay mode), exact text matching is used instead of hashing.
    """

    use_cache: bool = False
    stream: bool = False
    _cache: dict = {}

    def model_post_init(self, __content):
        if _REPLAY_SQLITE:
            self.use_cache = True
            self._cache = {}
            llm_calls = retrieve_all_llm_calls(_REPLAY_SQLITE)
            for llm_call in llm_calls:
                key = self.get_prompt_key(llm_call.prompt)
                self._cache[key] = [LLMEvent(output=llm_call.output)]
            logger.info(f"Enforced LLM cache from {_REPLAY_SQLITE}, {len(self._cache)} entries")
            return
        elif not self.use_cache:
            return
        logger.info("Use LLM Cache")
        param_hash = self._key(json.dumps({k: v for k, v in self.parameters.items() if k != "token"}))
        name = self.model_name.replace("/", "__")
        self._cache_file = f"llm_cache_{name}_{param_hash}.jsonl"
        if os.path.exists(self._cache_file):
            with open(self._cache_file) as f:
                for line in f:
                    key, event_dict = json.loads(line)
                    if key not in self._cache:
                        self._cache[key] = []
                    self._cache[key].append(event_dict)
            logger.info(f"Loaded cache with {len(self._cache)} keys")
        else:
            logger.info("Cache file not found")

    def reindex_log(self):
        """
        Reindex the log data into cache.

        This method iterates through the log entries, validates each prompt and output,
        and adds them to the cache using the prompt key as index. Each entry is converted
        to an LLMEvent model before caching.

        Side Effects:
            - Updates the internal cache with log data
            - Logs the total number of reindexed entries at INFO level
        """
        cnt = 0
        for log_data in self._log:
            key = self.get_prompt_key(Prompt.model_validate(log_data["prompt"]))
            self._add_to_cache(key, LLMEvent(output=LLMOutput.model_validate(log_data["output"])).model_dump())
            cnt += 1
        logger.info(f"Reindexed {cnt} log entries")

    def _add_to_cache(self, key: str, event_dict: dict):
        if not self.use_cache:
            return
        if key not in self._cache:
            self._cache[key] = []
        self._cache[key].append(event_dict)
        with open(self._cache_file, "a") as f:
            f.write(json.dumps((key, event_dict), ensure_ascii=False) + "\n")

    def get_prompt_key(self, prompt: Prompt) -> str:
        prompt_text = json.dumps(prompt.model_dump(exclude={"id"}), ensure_ascii=False, sort_keys=True)
        return self._key(prompt_text)

    def _key(self, text: str) -> str:
        if _REPLAY_SQLITE:
            # use exact text as a key during testing
            return text
        return hashlib.md5(text.encode("utf-8")).hexdigest()

    def generate(self, prompt: Prompt, **kwargs) -> LLMStream:
        """Generate a response stream from the language model based on the given prompt.

        This method handles both cached and new responses, implementing a caching mechanism
        for LLM responses to avoid redundant API calls.

        Args:
            prompt (Prompt): The prompt object containing messages to send to the LLM.
            **kwargs (dict, optional): Additional arguments to pass to the underlying LLM implementation.

        Returns:
            LLMStream: A stream of LLM events containing the model's response.

        Raises:
            ValueError: If cache miss occurs when replay mode is enabled (_REPLAY_SQLITE is True).

        Notes:
            - If caching is enabled and the prompt exists in cache, returns cached response
            - If generating new response, tokens are counted and added to total token count
            - All generated events are cached for future use if caching is enabled
            - Output is logged through the logging system
        """

        def _implementation():
            key = self.get_prompt_key(prompt)
            if self.use_cache and key in self._cache:
                logger.debug(colored(f"llm cache hit, {len(self._cache[key])} events", "green"))
                for event_dict in self._cache[key]:
                    event = LLMEvent.model_validate(event_dict)
                    if event.output is not None:
                        self.log_output(prompt, event.output, cached=True)
                    yield event
            else:
                if _REPLAY_SQLITE:
                    closest, score = closest_prompt(key, list(self._cache.keys()))
                    logger.error(
                        f"llm cache miss, closest in cache has score {score:.3f}\nDIFF:\n{diff_strings(key, closest)}"
                    )
                    raise ValueError(f"llm cache miss not allowed, prompt: {key}")
                toks = self.count_tokens(prompt.messages)
                self.token_count += toks
                logger.debug(f"{toks} prompt tokens, total: {self.token_count}")
                for event in self._generate(prompt, **kwargs):
                    self._add_to_cache(key, event.model_dump())
                    # note: the underlying LLM will log the output
                    yield event

        return LLMStream(_implementation(), prompt)

    @abstractmethod
    def _generate(self, prompt: Prompt, **kwargs) -> Generator[LLMEvent, None, None]:
        pass

generate(prompt, **kwargs)

Generate a response stream from the language model based on the given prompt.

This method handles both cached and new responses, implementing a caching mechanism for LLM responses to avoid redundant API calls.

Parameters:

  • prompt (Prompt) –

    The prompt object containing messages to send to the LLM.

  • **kwargs (dict, default: {} ) –

    Additional arguments to pass to the underlying LLM implementation.

Returns:

  • LLMStream ( LLMStream ) –

    A stream of LLM events containing the model's response.

Raises:

  • ValueError

    If cache miss occurs when replay mode is enabled (_REPLAY_SQLITE is True).

Notes
  • If caching is enabled and the prompt exists in cache, returns cached response
  • If generating new response, tokens are counted and added to total token count
  • All generated events are cached for future use if caching is enabled
  • Output is logged through the logging system
Source code in tapeagents/llms.py
def generate(self, prompt: Prompt, **kwargs) -> LLMStream:
    """Generate a response stream from the language model based on the given prompt.

    This method handles both cached and new responses, implementing a caching mechanism
    for LLM responses to avoid redundant API calls.

    Args:
        prompt (Prompt): The prompt object containing messages to send to the LLM.
        **kwargs (dict, optional): Additional arguments to pass to the underlying LLM implementation.

    Returns:
        LLMStream: A stream of LLM events containing the model's response.

    Raises:
        ValueError: If cache miss occurs when replay mode is enabled (_REPLAY_SQLITE is True).

    Notes:
        - If caching is enabled and the prompt exists in cache, returns cached response
        - If generating new response, tokens are counted and added to total token count
        - All generated events are cached for future use if caching is enabled
        - Output is logged through the logging system
    """

    def _implementation():
        key = self.get_prompt_key(prompt)
        if self.use_cache and key in self._cache:
            logger.debug(colored(f"llm cache hit, {len(self._cache[key])} events", "green"))
            for event_dict in self._cache[key]:
                event = LLMEvent.model_validate(event_dict)
                if event.output is not None:
                    self.log_output(prompt, event.output, cached=True)
                yield event
        else:
            if _REPLAY_SQLITE:
                closest, score = closest_prompt(key, list(self._cache.keys()))
                logger.error(
                    f"llm cache miss, closest in cache has score {score:.3f}\nDIFF:\n{diff_strings(key, closest)}"
                )
                raise ValueError(f"llm cache miss not allowed, prompt: {key}")
            toks = self.count_tokens(prompt.messages)
            self.token_count += toks
            logger.debug(f"{toks} prompt tokens, total: {self.token_count}")
            for event in self._generate(prompt, **kwargs):
                self._add_to_cache(key, event.model_dump())
                # note: the underlying LLM will log the output
                yield event

    return LLMStream(_implementation(), prompt)

reindex_log()

Reindex the log data into cache.

This method iterates through the log entries, validates each prompt and output, and adds them to the cache using the prompt key as index. Each entry is converted to an LLMEvent model before caching.

Side Effects
  • Updates the internal cache with log data
  • Logs the total number of reindexed entries at INFO level
Source code in tapeagents/llms.py
def reindex_log(self):
    """
    Reindex the log data into cache.

    This method iterates through the log entries, validates each prompt and output,
    and adds them to the cache using the prompt key as index. Each entry is converted
    to an LLMEvent model before caching.

    Side Effects:
        - Updates the internal cache with log data
        - Logs the total number of reindexed entries at INFO level
    """
    cnt = 0
    for log_data in self._log:
        key = self.get_prompt_key(Prompt.model_validate(log_data["prompt"]))
        self._add_to_cache(key, LLMEvent(output=LLMOutput.model_validate(log_data["output"])).model_dump())
        cnt += 1
    logger.info(f"Reindexed {cnt} log entries")

LLM

Bases: BaseModel, ABC

An abstract base class representing a Large Language Model (LLM).

This class defines the interface for interacting with different LLM implementations. It handles basic LLM functionality like token counting, generation, and logging.

Attributes:

  • model_name (str) –

    Name of the LLM model

  • parameters (dict) –

    Model-specific parameters for generation

  • context_size (int) –

    Maximum context size in tokens (default: 32000)

  • tokenizer_name (str) –

    Name of the tokenizer used

  • tokenizer (Any) –

    Tokenizer instance

  • token_count (int) –

    Running count of tokens processed

  • _log (list) –

    Internal log of LLM calls

Note

This is an abstract class and requires implementation of the abstract methods in derived classes.
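
As a minimal sketch, a hypothetical EchoLLM subclass that implements the three abstract methods (signatures follow the source below; the token count and training text are placeholder logic):

class EchoLLM(LLM):
    """Toy LLM that echoes the last user message back; for illustration only."""

    model_name: str = "echo"

    def generate(self, prompt: Prompt, **kwargs) -> LLMStream:
        def _implementation():
            # Echo the content of the last message as the assistant output.
            text = prompt.messages[-1]["content"] if prompt.messages else ""
            yield LLMEvent(output=LLMOutput(content=text))

        return LLMStream(_implementation(), prompt=prompt)

    def count_tokens(self, messages: list[dict] | str) -> int:
        # Crude whitespace-based count; a real implementation would use a tokenizer.
        return len(str(messages).split())

    def make_training_text(self, prompt: Prompt, output: LLMOutput) -> TrainingText:
        return TrainingText(text=output.content or "", n_predicted=self.count_tokens(output.content or ""))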

Methods:

  • count_tokens

    Count tokens in messages or text

  • generate

    Generate text from a given prompt

  • log_output

    Logs the output of an LLM call along with its metadata.

  • make_training_text

    Create training text from prompt and output.

Source code in tapeagents/llms.py
class LLM(BaseModel, ABC):
    """
    An abstract base class representing a Large Language Model (LLM).

    This class defines the interface for interacting with different LLM implementations.
    It handles basic LLM functionality like token counting, generation, and logging.

    Attributes:
        model_name (str): Name of the LLM model
        parameters (dict): Model-specific parameters for generation
        context_size (int): Maximum context size in tokens (default: 32000)
        tokenizer_name (str): Name of the tokenizer used
        tokenizer (Any): Tokenizer instance
        token_count (int): Running count of tokens processed
        _log (list): Internal log of LLM calls

    Note:
        This is an abstract class and requires implementation of the abstract methods
        in derived classes.
    """

    model_name: str
    parameters: dict = {}
    context_size: int = 32000
    tokenizer_name: str = ""
    tokenizer: Any = None

    token_count: int = 0
    _log: list = []

    @abstractmethod
    def generate(self, prompt: Prompt, **kwargs) -> LLMStream:
        """
        Generate text from a given prompt

        Args:
            prompt (Prompt): The prompt object containing messages to send to the LLM.
            **kwargs (dict, optional): Additional arguments to pass to the underlying LLM implementation.

        Returns:
            LLMStream: A stream of LLM events containing the model's response.
        """
        pass

    @abstractmethod
    def count_tokens(self, messages: list[dict] | str) -> int:
        """
        Count tokens in messages or text

        Args:
            messages (Union[List[Dict], str]): List of messages or text to count tokens in

        Returns:
            int: Number of tokens in the messages or text
        """
        pass

    @abstractmethod
    def make_training_text(self, prompt: Prompt, output: LLMOutput) -> TrainingText:
        """
        Create training text from prompt and output.

        Args:
            prompt (Prompt): The prompt object containing messages used to generate the output.
            output (LLMOutput): The output generated by the LLM.

        Returns:
            TrainingText: The training text object containing the prompt and output.
        """
        pass

    def log_output(self, prompt: Prompt, message: LLMOutput, cached: bool = False):
        """
        Logs the output of an LLM call along with its metadata.

        Args:
            prompt (Prompt): The prompt object containing the input messages for the LLM.
            message (LLMOutput): The output message generated by the LLM.
            cached (bool, optional): Indicates whether the output was retrieved from cache. Defaults to False.
        """
        llm_call = LLMCall(
            timestamp=datetime.datetime.now().isoformat(),
            prompt=prompt,
            output=message,
            prompt_length_tokens=self.count_tokens(prompt.messages),
            output_length_tokens=self.count_tokens(message.content) if message.content else 0,
            cached=cached,
        )
        self._log.append(llm_call.model_dump())
        observe_llm_call(llm_call)

count_tokens(messages) abstractmethod

Count tokens in messages or text

Parameters:

  • messages (Union[List[Dict], str]) –

    List of messages or text to count tokens in

Returns:

  • int ( int ) –

    Number of tokens in the messages or text

Source code in tapeagents/llms.py
@abstractmethod
def count_tokens(self, messages: list[dict] | str) -> int:
    """
    Count tokens in messages or text

    Args:
        messages (Union[List[Dict], str]): List of messages or text to count tokens in

    Returns:
        int: Number of tokens in the messages or text
    """
    pass

generate(prompt, **kwargs) abstractmethod

Generate text from a given prompt

Parameters:

  • prompt (Prompt) –

    The prompt object containing messages to send to the LLM.

  • **kwargs (dict, default: {} ) –

    Additional arguments to pass to the underlying LLM implementation.

Returns:

  • LLMStream ( LLMStream ) –

    A stream of LLM events containing the model's response.

Source code in tapeagents/llms.py
@abstractmethod
def generate(self, prompt: Prompt, **kwargs) -> LLMStream:
    """
    Generate text from a given prompt

    Args:
        prompt (Prompt): The prompt object containing messages to send to the LLM.
        **kwargs (dict, optional): Additional arguments to pass to the underlying LLM implementation.

    Returns:
        LLMStream: A stream of LLM events containing the model's response.
    """
    pass

log_output(prompt, message, cached=False)

Logs the output of an LLM call along with its metadata.

Parameters:

  • prompt (Prompt) –

    The prompt object containing the input messages for the LLM.

  • message (LLMOutput) –

    The output message generated by the LLM.

  • cached (bool, default: False ) –

    Indicates whether the output was retrieved from cache. Defaults to False.

Source code in tapeagents/llms.py
def log_output(self, prompt: Prompt, message: LLMOutput, cached: bool = False):
    """
    Logs the output of an LLM call along with its metadata.

    Args:
        prompt (Prompt): The prompt object containing the input messages for the LLM.
        message (LLMOutput): The output message generated by the LLM.
        cached (bool, optional): Indicates whether the output was retrieved from cache. Defaults to False.
    """
    llm_call = LLMCall(
        timestamp=datetime.datetime.now().isoformat(),
        prompt=prompt,
        output=message,
        prompt_length_tokens=self.count_tokens(prompt.messages),
        output_length_tokens=self.count_tokens(message.content) if message.content else 0,
        cached=cached,
    )
    self._log.append(llm_call.model_dump())
    observe_llm_call(llm_call)

make_training_text(prompt, output) abstractmethod

Create training text from prompt and output.

Parameters:

  • prompt (Prompt) –

    The prompt object containing messages used to generate the output.

  • output (LLMOutput) –

    The output generated by the LLM.

Returns:

  • TrainingText ( TrainingText ) –

    The training text object containing the prompt and output.

Source code in tapeagents/llms.py
@abstractmethod
def make_training_text(self, prompt: Prompt, output: LLMOutput) -> TrainingText:
    """
    Create training text from prompt and output.

    Args:
        prompt (Prompt): The prompt object containing messages used to generate the output.
        output (LLMOutput): The output generated by the LLM.

    Returns:
        TrainingText: The training text object containing the prompt and output.
    """
    pass

LLMEvent

Bases: BaseModel

An event class representing either a chunk of LLM output or the final LLM output.

This class encapsulates events that occur during LLM processing, handling both intermediate chunks of output and the final complete output.

Attributes:

  • chunk (str) –

    A partial text output from the LLM stream. None if this event represents a complete output.

  • output (LLMOutput) –

    The complete output from the LLM. None if this event represents a partial chunk.
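
For example, a streaming consumer can branch on which field is set (sketch; assumes an llm instance configured with stream=True and an existing prompt):

chunks = []
for event in llm.generate(prompt):
    if event.chunk is not None:
        chunks.append(event.chunk)   # partial text as it arrives
    elif event.output is not None:
        final_output = event.output  # complete LLMOutput at the end of the stream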

Source code in tapeagents/llms.py
class LLMEvent(BaseModel):
    """An event class representing either a chunk of LLM output or the final LLM output.

    This class encapsulates events that occur during LLM processing, handling both
    intermediate chunks of output and the final complete output.

    Attributes:
        chunk (str, optional): A partial text output from the LLM stream. None if this
            event represents a complete output.
        output (LLMOutput, optional): The complete output from the LLM. None if this
            event represents a partial chunk.
    """

    chunk: str | None = None
    output: LLMOutput | None = None

LLMStream

A wrapper class for LLM generators that provides convenient iteration and output extraction.

This class wraps a generator that yields LLMEvents and provides methods to:

  • Iterate through events
  • Extract complete LLM output
  • Get the assistant's response text

Attributes:

  • generator

    Generator yielding LLMEvents or None if empty

  • prompt

    The prompt used to generate the LLM response.

Raises:

  • ValueError

    When trying to iterate a null stream, when no output is produced, or when the output is not an assistant message with content

Methods:

  • get_output

    Returns first LLMOutput found in events

  • get_text

    Returns content of first assistant message found
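
A brief usage sketch (llm and prompt stand for any LLM instance and Prompt):

stream = llm.generate(prompt)  # any LLM implementation returns an LLMStream
text = stream.get_text()       # raises ValueError if no assistant message with content is produced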

Source code in tapeagents/llms.py
class LLMStream:
    """A wrapper class for LLM generators that provides convenient iteration and output extraction.

    This class wraps a generator that yields LLMEvents and provides methods to:

    - Iterate through events
    - Extract complete LLM output
    - Get the assistant's response text

    Attributes:
        generator: Generator yielding LLMEvents or None if empty
        prompt: The prompt used to generate the LLM response.

    Raises:
        ValueError: When trying to iterate null stream, when no output is produced,
                   or when output is not an assistant message with content
    """

    def __init__(self, generator: Generator[LLMEvent, None, None] | None, prompt: Prompt):
        self.generator = generator
        self.prompt = prompt

    def __bool__(self):
        return self.generator is not None

    def __iter__(self):
        if self.generator is None:
            raise ValueError("can't iterate a null stream")
        return self

    def __next__(self) -> LLMEvent:
        if self.generator is None:
            raise StopIteration
        return next(self.generator)

    def get_output(self) -> LLMOutput:
        """Returns first LLMOutput found in events"""
        for event in self:
            if event.output:
                return event.output
        raise ValueError("LLM did not produce an output")

    def get_text(self) -> str:
        """Returns content of first assistant message found"""
        o = self.get_output()
        if not o.role == "assistant" or o.content is None:
            raise ValueError("LLM did not produce an assistant message")
        return o.content

get_output()

Returns first LLMOutput found in events

Source code in tapeagents/llms.py
def get_output(self) -> LLMOutput:
    """Returns first LLMOutput found in events"""
    for event in self:
        if event.output:
            return event.output
    raise ValueError("LLM did not produce an output")

get_text()

Returns content of first assistant message found

Source code in tapeagents/llms.py
def get_text(self) -> str:
    """Returns content of first assistant message found"""
    o = self.get_output()
    if not o.role == "assistant" or o.content is None:
        raise ValueError("LLM did not produce an assistant message")
    return o.content

LiteLLM

Bases: CachedLLM

A LiteLLM implementation of the LLM interface.

This class provides integration with the LiteLLM library for making LLM API calls. It supports both streaming and non-streaming responses, provides token counting, and handles API timeouts with retries. Streaming responses are handled by yielding chunks of text as they arrive; non-streaming responses return complete messages.

Note

Function calling during streaming is not yet implemented and will raise NotImplementedError.
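
A hypothetical configuration sketch (the model name and parameters are placeholders; parameters are passed through to litellm.completion, and Prompt is assumed to accept a messages list):

llm = LiteLLM(
    model_name="gpt-4o-mini",            # any model identifier LiteLLM can route
    parameters={"temperature": 0.0},     # forwarded to litellm.completion
    use_cache=True,                      # optional: reuse responses from the jsonl cache
)
answer = llm.generate(Prompt(messages=[{"role": "user", "content": "Hello!"}])).get_text()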

Methods:

  • count_tokens

    Count the number of tokens in a message or string.

  • make_training_text

    Generates the training text for the model.

Source code in tapeagents/llms.py
class LiteLLM(CachedLLM):
    """A LiteLLM implementation of the LLM interface.

    This class provides integration with the LiteLLM library for making LLM API calls.
    It supports both streaming and non-streaming responses, token counting, and handles API timeouts with retries.
    Streaming responses are handled by yielding chunks of text as they arrive.
    Non-streaming responses return complete messages.

    Note:
        Function calling during streaming is not yet implemented and will raise NotImplementedError.
    """

    def count_tokens(self, messages: list[dict] | str) -> int:
        """
        Count the number of tokens in a message or string.

        Args:
            messages (Union[List[Dict], str]): List of messages or text to count tokens in.

        Returns:
            int: The number of tokens in the messages or text.
        """
        if isinstance(messages, str):
            return litellm.token_counter(model=self.model_name, text=messages)
        else:
            return litellm.token_counter(model=self.model_name, messages=messages)

    def _generate(self, prompt: Prompt, **kwargs) -> Generator[LLMEvent, None, None]:
        while True:
            try:
                response = litellm.completion(
                    model=self.model_name,
                    messages=prompt.messages,
                    tools=prompt.tools,
                    stream=self.stream,
                    **self.parameters,
                )
                break
            except openai.APITimeoutError:
                logger.error("API Timeout, retrying in 1 sec")
                time.sleep(1.0)
        if self.stream:
            buffer = []
            for part in response:
                assert isinstance(part, litellm.ModelResponse)
                if isinstance(part.choices[0], litellm.utils.StreamingChoices):
                    content_delta = part.choices[0].delta.content
                    if content_delta:
                        buffer.append(content_delta)
                        yield LLMEvent(chunk=content_delta)
                    tool_delta = part.choices[0].delta.tool_calls
                    if tool_delta:
                        raise NotImplementedError("TODO: streaming with function calls not implemented yet")
                else:
                    raise ValueError(f"Unexpected response {part.model_dump()}")
            output = LLMOutput(content="".join(buffer))
        else:
            assert isinstance(response, litellm.ModelResponse)
            assert isinstance(response.choices[0], litellm.utils.Choices)
            output = response.choices[0].message
        self.log_output(prompt, output)
        yield LLMEvent(output=output)

    def make_training_text(self, *args, **kwargs) -> TrainingText:
        """
        Generates the training text for the model.

        This method should be implemented by subclasses to provide the specific
        logic for creating the training text.

        Args:
            *args (list): Variable length argument list.
            **kwargs (dict, optional): Arbitrary keyword arguments.

        Returns:
            TrainingText: The generated training text.

        Raises:
            NotImplementedError: If the method is not implemented by a subclass.
        """
        raise NotImplementedError()

count_tokens(messages)

Count the number of tokens in a message or string.

Parameters:

  • messages (Union[List[Dict], str]) –

    List of messages or text to count tokens in.

Returns:

  • int ( int ) –

    The number of tokens in the messages or text.

Source code in tapeagents/llms.py
def count_tokens(self, messages: list[dict] | str) -> int:
    """
    Count the number of tokens in a message or string.

    Args:
        messages (Union[List[Dict], str]): List of messages or text to count tokens in.

    Returns:
        int: The number of tokens in the messages or text.
    """
    if isinstance(messages, str):
        return litellm.token_counter(model=self.model_name, text=messages)
    else:
        return litellm.token_counter(model=self.model_name, messages=messages)

make_training_text(*args, **kwargs)

Generates the training text for the model.

This method should be implemented by subclasses to provide the specific logic for creating the training text.

Parameters:

  • *args (list, default: () ) –

    Variable length argument list.

  • **kwargs (dict, default: {} ) –

    Arbitrary keyword arguments.

Returns:

  • TrainingText ( TrainingText ) –

    The generated training text.

Raises:

  • NotImplementedError

    If the method is not implemented by a subclass.

Source code in tapeagents/llms.py
def make_training_text(self, *args, **kwargs) -> TrainingText:
    """
    Generates the training text for the model.

    This method should be implemented by subclasses to provide the specific
    logic for creating the training text.

    Args:
        *args (list): Variable length argument list.
        **kwargs (dict, optional): Arbitrary keyword arguments.

    Returns:
        TrainingText: The generated training text.

    Raises:
        NotImplementedError: If the method is not implemented by a subclass.
    """
    raise NotImplementedError()

MockLLM

Bases: LLM

A mock LLM implementation for testing purposes.

This class simulates an LLM by returning predefined responses in a cyclic manner. It tracks the prompts it receives and maintains a call counter.

Attributes:

  • model_name (str) –

    Name of the mock model, defaults to "mock"

  • call_number (int) –

    Counter for number of calls made to generate, defaults to 0

  • mock_outputs (list[str]) –

    List of predefined responses to cycle through

  • prompts (list[Prompt]) –

    List of received prompts
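
For example, a test might exercise it like this (sketch; assumes Prompt accepts a messages list):

llm = MockLLM(mock_outputs=["Agent: hi there"])
reply = llm.generate(Prompt(messages=[{"role": "user", "content": "hello"}])).get_output().content
assert reply == "Agent: hi there"
assert llm.call_number == 1 and len(llm.prompts) == 1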

Source code in tapeagents/llms.py
class MockLLM(LLM):
    """A mock LLM implementation for testing purposes.

    This class simulates an LLM by returning predefined responses in a cyclic manner.
    It tracks the prompts it receives and maintains a call counter.

    Attributes:
        model_name (str): Name of the mock model, defaults to "mock"
        call_number (int): Counter for number of calls made to generate, defaults to 0
        mock_outputs (list[str]): List of predefined responses to cycle through
        prompts (list[Prompt]): List of received prompts
    """

    model_name: str = "mock"
    call_number: int = 0
    mock_outputs: list[str] = [
        "Agent: I'm good, thank you",
        "Agent: Sure, I worked at ServiceNow for 10 years",
        "Agent: I have 10 zillion parameters",
    ]
    prompts: list[Prompt] = []

    def generate(self, prompt: Prompt) -> LLMStream:
        def _implementation():
            self.prompts.append(prompt)
            output = self.mock_outputs[self.call_number % len(self.mock_outputs)]
            time.sleep(0.01)
            yield LLMEvent(output=LLMOutput(content=output))
            self.call_number += 1

        return LLMStream(_implementation(), prompt=prompt)

    def count_tokens(self, messages: list[dict] | str) -> int:
        return 42

    def make_training_text(self, prompt: Prompt, output: LLMOutput) -> TrainingText:
        return TrainingText(text="mock trace", n_predicted=10)

ReplayLLM

Bases: LLM

Specialized LLM class that replays previously recorded LLM interactions.

Loads and replays model interactions from a SQLite database, allowing for deterministic replay of previous LLM conversations without making new API calls.

The class is useful for:

  • Testing and debugging LLM interactions
  • Reproducing specific model behaviors
  • Avoiding repeated API calls during development
  • Creating deterministic test scenarios

Attributes:

  • outputs (dict[str, str]) –

    Dictionary mapping prompt strings to their recorded outputs

  • llm_calls (list[LLMCall]) –

    List of recorded LLM call objects

  • count_tokens_fn (Callable) –

    Function to count tokens in prompts/messages

  • make_training_text_fn (Callable) –

    Function to create training text from prompt/output pairs

Raises:

  • FatalError

    When a prompt is not found in the recorded outputs

  • AssertionError

    When the specified SQLite database file doesn't exist

Methods:

  • count_tokens

    Counts the number of tokens in the given messages.

  • from_llm

    Create a ReplayLLM instance from an existing LLM and a SQLite database file.

  • generate

    Generates an LLMStream based on the provided prompt.

  • make_training_text

    Generates training text based on the provided prompt and output.
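
A typical replay setup (sketch; the run directory is a placeholder and must contain the SQLite database written during the original run, and recorded_prompt stands for a prompt identical to one captured there):

live_llm = LiteLLM(model_name="gpt-4o-mini")
replay_llm = ReplayLLM.from_llm(live_llm, run_dir="outputs/my_run")
text = replay_llm.generate(recorded_prompt).get_output().content  # raises FatalError for prompts not in the recording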

Source code in tapeagents/llms.py
class ReplayLLM(LLM):
    """
    Specialized LLM class that replays previously recorded LLM interactions.

    Loads and replays model interactions from a SQLite database, allowing for
    deterministic replay of previous LLM conversations without making new API calls.

    The class is useful for:

    - Testing and debugging LLM interactions
    - Reproducing specific model behaviors
    - Avoiding repeated API calls during development
    - Creating deterministic test scenarios

    Attributes:
        outputs (dict[str, str]): Dictionary mapping prompt strings to their recorded outputs
        llm_calls (list[LLMCall]): List of recorded LLM call objects
        count_tokens_fn (Callable): Function to count tokens in prompts/messages
        make_training_text_fn (Callable): Function to create training text from prompt/output pairs

    Raises:
        FatalError: When a prompt is not found in the recorded outputs
        AssertionError: When the specified SQLite database file doesn't exist
    """

    outputs: dict[str, str] = Field(default_factory=dict)
    llm_calls: list[LLMCall]
    count_tokens_fn: Callable = lambda x: 0
    make_training_text_fn: Callable = lambda x, y: TrainingText(text="", n_predicted=0)

    @classmethod
    def from_llm(cls, llm: LLM, run_dir: str, prompts_file: str = DB_DEFAULT_FILENAME):
        """
        Create a ReplayLLM instance from an existing LLM and a SQLite database file.

        Args:
            cls (Type): The class to instantiate.
            llm (LLM): The original LLM instance.
            run_dir (str): The directory where the SQLite database file is located.
            prompts_file (str, optional): The name of the SQLite database file. Defaults to DB_DEFAULT_FILENAME.

        Returns:
            (ReplayLLM): An instance of ReplayLLM initialized with the LLM calls from the SQLite database.

        Raises:
            AssertionError: If the SQLite database file does not exist at the specified path.
        """
        sqlite_fpath = os.path.join(run_dir, prompts_file)
        assert os.path.exists(sqlite_fpath), f"Sqlite not found: {sqlite_fpath}"
        llm_calls = retrieve_all_llm_calls(sqlite_fpath)
        replay_llm = ReplayLLM(
            llm_calls=llm_calls,
            model_name=llm.tokenizer_name or llm.model_name,
            context_size=llm.context_size,
        )
        replay_llm.tokenizer = llm.tokenizer
        replay_llm.count_tokens_fn = llm.count_tokens
        replay_llm.make_training_text_fn = llm.make_training_text
        return replay_llm

    def model_post_init(self, __context: Any) -> None:
        dups = 0
        for llm_call in self.llm_calls:
            prompt_key = json.dumps(llm_call.prompt.messages, indent=2, ensure_ascii=False, sort_keys=True)
            output = llm_call.output.content or ""
            if prompt_key in self.outputs and output != self.outputs[prompt_key]:
                logger.debug(f"Output duplicate, using last value!\nOLD:{self.outputs[prompt_key]}\nNEW:{output}")
                dups += 1
            self.outputs[prompt_key] = output
        logger.info(f"Loaded {len(self.outputs)} outputs, {dups} duplicates")
        return super().model_post_init(__context)

    def generate(self, prompt: Prompt, **kwargs) -> LLMStream:
        """
        Generates an LLMStream based on the provided prompt.

        This method checks if the prompt has been previously processed and cached. If a cached output is found,
        it is returned. Otherwise, it attempts to find the closest known prompt and logs the differences. If no
        similar prompt is found, a FatalError is raised.

        Args:
            prompt (Prompt): The prompt object containing the messages to be processed.
            **kwargs (dict, optional): Additional keyword arguments.

        Returns:
            LLMStream: A stream of LLM events containing the generated output.

        Raises:
            FatalError: If the prompt is not found in the cache and no similar prompt is found.
        """

        def _implementation():
            prompt_key = json.dumps(prompt.messages, indent=2, ensure_ascii=False, sort_keys=True)
            if prompt_key in self.outputs:
                logger.debug(colored("prompt cache hit", "green"))
                output = self.outputs[prompt_key]
            else:
                logger.warning(
                    colored(f"prompt of size {len(prompt_key)} not found, checking similar ones..", "yellow")
                )
                known_prompts = list(self.outputs.keys())
                closest, score = closest_prompt(prompt_key, known_prompts)
                if score >= 0.7:
                    logger.warning(f"Closest prompt score {score:.3f}")
                    for i, (a, b) in enumerate(zip_longest(prompt.messages, json.loads(closest), fillvalue={})):
                        logger.warning(f"STEP{i}: {diff_strings(a.get('content', str(a)), b.get('content', str(b)))}\n")
                raise FatalError("prompt not found")
            yield LLMEvent(output=LLMOutput(content=output))

        return LLMStream(_implementation(), prompt=prompt)

    def make_training_text(self, prompt: Prompt, output: LLMOutput) -> TrainingText:
        """
        Generates training text based on the provided prompt and output.

        Args:
            prompt (Prompt): The input prompt to generate training text from.
            output (LLMOutput): The output generated by the language model.

        Returns:
            TrainingText: The generated training text.
        """
        return self.make_training_text_fn(prompt, output)

    def count_tokens(self, messages: list[dict] | str) -> int:
        """
        Counts the number of tokens in the given messages.

        Args:
            messages (Union[list[dict], str]): A list of message dictionaries or a single string message.

        Returns:
            int: The total number of tokens in the messages.
        """
        return self.count_tokens_fn(messages)

count_tokens(messages)

Counts the number of tokens in the given messages.

Parameters:

  • messages (Union[list[dict], str]) –

    A list of message dictionaries or a single string message.

Returns:

  • int ( int ) –

    The total number of tokens in the messages.

Source code in tapeagents/llms.py
def count_tokens(self, messages: list[dict] | str) -> int:
    """
    Counts the number of tokens in the given messages.

    Args:
        messages (Union[list[dict], str]): A list of message dictionaries or a single string message.

    Returns:
        int: The total number of tokens in the messages.
    """
    return self.count_tokens_fn(messages)

from_llm(llm, run_dir, prompts_file=DB_DEFAULT_FILENAME) classmethod

Create a ReplayLLM instance from an existing LLM and a SQLite database file.

Parameters:

  • cls (Type) –

    The class to instantiate.

  • llm (LLM) –

    The original LLM instance.

  • run_dir (str) –

    The directory where the SQLite database file is located.

  • prompts_file (str, default: DB_DEFAULT_FILENAME ) –

    The name of the SQLite database file. Defaults to DB_DEFAULT_FILENAME.

Returns:

  • ReplayLLM

    An instance of ReplayLLM initialized with the LLM calls from the SQLite database.

Raises:

  • AssertionError

    If the SQLite database file does not exist at the specified path.

Source code in tapeagents/llms.py
@classmethod
def from_llm(cls, llm: LLM, run_dir: str, prompts_file: str = DB_DEFAULT_FILENAME):
    """
    Create a ReplayLLM instance from an existing LLM and a SQLite database file.

    Args:
        cls (Type): The class to instantiate.
        llm (LLM): The original LLM instance.
        run_dir (str): The directory where the SQLite database file is located.
        prompts_file (str, optional): The name of the SQLite database file. Defaults to DB_DEFAULT_FILENAME.

    Returns:
        (ReplayLLM): An instance of ReplayLLM initialized with the LLM calls from the SQLite database.

    Raises:
        AssertionError: If the SQLite database file does not exist at the specified path.
    """
    sqlite_fpath = os.path.join(run_dir, prompts_file)
    assert os.path.exists(sqlite_fpath), f"Sqlite not found: {sqlite_fpath}"
    llm_calls = retrieve_all_llm_calls(sqlite_fpath)
    replay_llm = ReplayLLM(
        llm_calls=llm_calls,
        model_name=llm.tokenizer_name or llm.model_name,
        context_size=llm.context_size,
    )
    replay_llm.tokenizer = llm.tokenizer
    replay_llm.count_tokens_fn = llm.count_tokens
    replay_llm.make_training_text_fn = llm.make_training_text
    return replay_llm

generate(prompt, **kwargs)

Generates an LLMStream based on the provided prompt.

This method checks whether the prompt has been previously processed and cached. If a cached output is found, it is returned. Otherwise, the closest known prompt is located and the differences are logged before a FatalError is raised.

Parameters:

  • prompt (Prompt) –

    The prompt object containing the messages to be processed.

  • **kwargs (dict, default: {} ) –

    Additional keyword arguments.

Returns:

  • LLMStream ( LLMStream ) –

    A stream of LLM events containing the generated output.

Raises:

  • FatalError

    If the prompt is not found in the cache and no similar prompt is found.

Source code in tapeagents/llms.py
def generate(self, prompt: Prompt, **kwargs) -> LLMStream:
    """
    Generates an LLMStream based on the provided prompt.

    This method checks if the prompt has been previously processed and cached. If a cached output is found,
    it is returned. Otherwise, it attempts to find the closest known prompt and logs the differences. If no
    similar prompt is found, a FatalError is raised.

    Args:
        prompt (Prompt): The prompt object containing the messages to be processed.
        **kwargs (dict, optional): Additional keyword arguments.

    Returns:
        LLMStream: A stream of LLM events containing the generated output.

    Raises:
        FatalError: If the prompt is not found in the cache and no similar prompt is found.
    """

    def _implementation():
        prompt_key = json.dumps(prompt.messages, indent=2, ensure_ascii=False, sort_keys=True)
        if prompt_key in self.outputs:
            logger.debug(colored("prompt cache hit", "green"))
            output = self.outputs[prompt_key]
        else:
            logger.warning(
                colored(f"prompt of size {len(prompt_key)} not found, checking similar ones..", "yellow")
            )
            known_prompts = list(self.outputs.keys())
            closest, score = closest_prompt(prompt_key, known_prompts)
            if score >= 0.7:
                logger.warning(f"Closest prompt score {score:.3f}")
                for i, (a, b) in enumerate(zip_longest(prompt.messages, json.loads(closest), fillvalue={})):
                    logger.warning(f"STEP{i}: {diff_strings(a.get('content', str(a)), b.get('content', str(b)))}\n")
            raise FatalError("prompt not found")
        yield LLMEvent(output=LLMOutput(content=output))

    return LLMStream(_implementation(), prompt=prompt)

make_training_text(prompt, output)

Generates training text based on the provided prompt and output.

Parameters:

  • prompt (Prompt) –

    The input prompt to generate training text from.

  • output (LLMOutput) –

    The output generated by the language model.

Returns:

  • TrainingText ( TrainingText ) –

    The generated training text.

Source code in tapeagents/llms.py
def make_training_text(self, prompt: Prompt, output: LLMOutput) -> TrainingText:
    """
    Generates training text based on the provided prompt and output.

    Args:
        prompt (Prompt): The input prompt to generate training text from.
        output (LLMOutput): The output generated by the language model.

    Returns:
        TrainingText: The generated training text.
    """
    return self.make_training_text_fn(prompt, output)

TrainableLLM

Bases: CachedLLM

Class for interacting with trainable language models through OpenAI-compatible API endpoints.

This class implements functionality for both inference and training-related operations with language models served via Text Generation Inference (TGI) or vLLM endpoints that expose an OpenAI-compatible API interface. It supports both streaming and non-streaming modes, and includes methods for token counting and log probability calculations.

Attributes:

  • base_url (str) –

    Base URL of the API endpoint

  • api_token (str) –

    Authentication token for API access

Methods:

  • count_tokens

    Count the number of tokens in the given messages.

  • get_log_probs

    Calculate the log probabilities of the given output based on the provided prompt.

  • get_log_probs_chat_complete

    Calculate the log probabilities of the tokens in the completion generated by the language model.

  • get_log_probs_complete

    Get the log probabilities of the tokens in the output given the prompt.

  • load_tokenizer

    Loads the tokenizer for the model.

  • make_training_text

    Generates training text from a given prompt and LLM output.
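
A sketch of pointing TrainableLLM at a locally served OpenAI-compatible endpoint (URL, model id, and parameters are placeholders; the API token, if any, is read from the environment during model_post_init):

llm = TrainableLLM(
    base_url="http://localhost:8080",                  # vLLM or TGI server exposing /v1/chat/completions
    model_name="meta-llama/Llama-3.1-8B-Instruct",     # placeholder model id
    parameters={"temperature": 0.7, "max_tokens": 256},
    stream=False,
)
text = llm.generate(Prompt(messages=[{"role": "user", "content": "Hi!"}])).get_text()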

Source code in tapeagents/llms.py
class TrainableLLM(CachedLLM):
    """
    Class for interacting with trainable language models through OpenAI-compatible API endpoints.

    This class implements functionality for both inference and training-related operations with
    language models served via Text Generation Inference (TGI) or vLLM endpoints that expose
    an OpenAI-compatible API interface. It supports both streaming and non-streaming modes,
    and includes methods for token counting and log probability calculations.

    Attributes:
        base_url (str): Base URL of the API endpoint
        api_token (str): Authentication token for API access
    """

    # TODO: use OpenAI Python client when the certificate issue is resolved.
    # TODO: consider using litellm

    base_url: str
    api_token: str = Field(default="", exclude=True)

    def model_post_init(self, __context):
        super().model_post_init(__context)
        self.api_token = os.getenv(TAPEAGENTS_LLM_TOKEN, "")

    @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=2))
    def _generate(self, prompt: Prompt) -> Generator[LLMEvent, None, None]:
        headers = {"Content-Type": "application/json"}
        if self.api_token:
            headers |= {"Authorization": f"Bearer {self.api_token}"}
        data = {
            "model": self.model_name,
            "messages": prompt.messages,
            "stream": self.stream,
        }
        r = requests.post(
            url=f"{self.base_url}/v1/chat/completions",
            json=data | self.parameters,
            headers=headers,
            stream=self.stream,
            verify=False,
        )
        if not r.ok:
            logger.error(f"Failed to get completion: {r.text}")
            r.raise_for_status()
        if self.stream:
            response_buffer = []
            for byte_payload in r.iter_lines():
                if byte_payload == b"\n":
                    continue
                payload = byte_payload.decode("utf-8")
                if payload.startswith("data:"):
                    if payload == "data: [DONE]":
                        continue
                    json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
                    response_delta = json_payload["choices"][0]["delta"].get("content", "")
                    if not response_delta:
                        continue
                    response_buffer.append(response_delta)
                    yield LLMEvent(chunk=response_delta)
            output = LLMOutput(content="".join(response_buffer))
        else:
            data = r.json()
            try:
                content = data["choices"][0]["message"]["content"]
                if not content:
                    logger.warning(f"Empty completion {data}")
                output = LLMOutput(content=content)
            except Exception as e:
                logger.exception(f"Failed to parse llm response: {r}")
                raise e
        self.log_output(prompt, output)
        yield LLMEvent(output=output)

    def load_tokenizer(self):
        """
        Loads the tokenizer for the model.

        If the tokenizer is not already loaded, this method will import the
        `transformers` library and load the tokenizer using the model name or
        tokenizer name. If `_MOCK_TOKENIZER` is set, it will use that instead.

        Raises:
            ValueError: If neither `self.tokenizer_name` nor `self.model_name`
                        is provided and `_MOCK_TOKENIZER` is not set.
        """
        if self.tokenizer is None:
            import transformers

            name = _MOCK_TOKENIZER if _MOCK_TOKENIZER else (self.tokenizer_name or self.model_name)
            self.tokenizer = transformers.AutoTokenizer.from_pretrained(name)

    def make_training_text(self, prompt: Prompt, output: LLMOutput) -> TrainingText:
        """
        Generates training text from a given prompt and LLM output.

        This method loads the tokenizer and uses it to create training text
        suitable for training a language model.

        Args:
            prompt (Prompt): The input prompt to generate training text from.
            output (LLMOutput): The output from the language model to be used in training.

        Returns:
            TrainingText: The generated training text.
        """
        self.load_tokenizer()
        return trainable_llm_make_training_text(prompt, output, self.tokenizer)

    def get_log_probs_complete(self, prompt: str, output: str) -> list[float]:
        """
        Get the log probabilities of the tokens in the output given the prompt.

        This method sends a request to the language model API to generate the log probabilities
        for the tokens in the provided output, given the prompt. It uses the tokenizer to encode
        the prompt and output, and extracts the log probabilities from the API response.

        Args:
            prompt (str): The input prompt text.
            output (str): The output text for which log probabilities are to be calculated.

        Returns:
            list[float]: A list of log probabilities for each token in the output.

        Raises:
            RuntimeError: If the API response is not as expected or if there is a mismatch
                          between the tokens in the response and the provided output.
        """
        if not self.tokenizer:
            self.load_tokenizer()

        headers = {"Content-Type": "application/json"}
        if self.api_token:
            headers |= {"Authorization": f"Bearer {self.api_token}"}

        if self.tokenizer.bos_token and prompt.startswith(self.tokenizer.bos_token):
            prompt = prompt[len(self.tokenizer.bos_token) :]

        prompt_text = prompt + output
        generation_args = {
            "model": self.model_name,
            "prompt": prompt_text,
            "temperature": 0.0,
            "max_tokens": 0,
            "logprobs": 1,
            "echo": True,
            "include_stop_str_in_output": True,  # self.include_stop_str_in_output,
            "skip_special_tokens": False,
            "n": 1,  # number of completions to generate
            "stream": False,  # return a single completion and not a stream of lines
        }
        url = f"{self.base_url}/v1/completions"
        logger.debug(f"POST request to {url}")
        r = requests.post(url, json=generation_args, headers=headers, verify=False)
        r.raise_for_status()  # raise exception if status code is not in the 200s
        try:
            response = r.json()
            log_probs = response["choices"][0]["logprobs"]["token_logprobs"]
            prompt_encoded = self.tokenizer.encode(prompt, add_special_tokens=True)
            prompt_completion_encoded = self.tokenizer.encode(prompt + output, add_special_tokens=True)
            log_probs = log_probs[len(prompt_encoded) : len(prompt_completion_encoded)]
            tokens = response["choices"][0]["logprobs"]["tokens"]
            tokens = tokens[len(prompt_encoded) : len(prompt_completion_encoded)]
            assert "".join(tokens) == output, f"Tokens do not match completion: {''.join(tokens)} != {output}"
        except Exception as e:
            raise RuntimeError(f"Generation API wrong response: {r.text}", e)
        return log_probs

    def get_log_probs_chat_complete(self, prompt: Prompt, output: LLMOutput) -> list[float]:
        """
        Calculate the log probabilities of the tokens in the completion generated by the language model.

        This function sends a request to the language model API to generate completions and calculate log probabilities.
        The function uses the tokenizer to encode the prompt and completion texts.
        The log probabilities are extracted from the API response and validated against the original completion.

        Args:
            prompt (Prompt): The prompt containing the messages to be sent to the language model.
            output (LLMOutput): The output from the language model containing the generated completion.

        Returns:
            list[float]: A list of log probabilities for each token in the generated completion.

        Raises:
            RuntimeError: If the response from the generation API is incorrect or cannot be parsed.
        """
        headers = {"Content-Type": "application/json"}
        if self.api_token:
            headers |= {"Authorization": f"Bearer {self.api_token}"}

        time_t0 = time.time()
        prompt_text = self.tokenizer.apply_chat_template(prompt.messages, tokenize=False)
        completion = output.content or ""
        messages = prompt.messages + [output.model_dump()]
        prompt_text = self.tokenizer.apply_chat_template(prompt.messages, tokenize=False, add_generation_prompt=True)
        prompt_completion_text = self.tokenizer.apply_chat_template(messages, tokenize=False)
        if self.tokenizer.bos_token and prompt_text.startswith(self.tokenizer.bos_token):
            prompt_text = prompt_text[len(self.tokenizer.bos_token) :]
            prompt_completion_text = prompt_completion_text[len(self.tokenizer.bos_token) :]

        prompt_encoded = self.tokenizer.encode(prompt_text, add_special_tokens=True)
        prompt_completion_encoded = self.tokenizer.encode(prompt_completion_text, add_special_tokens=True)

        generation_args = {
            "model": self.model_name,
            "messages": messages,
            "temperature": 0.0,
            "max_tokens": 1,
            "logprobs": 1,
            "echo": True,
            "include_stop_str_in_output": True,  # self.include_stop_str_in_output,
            "skip_special_tokens": False,
            "n": 1,  # number of completions to generate
            "stream": False,  # return a single completion and not a stream of lines
        }
        r = requests.post(
            url=f"{self.base_url}/v1/chat/completions",
            json=generation_args,
            headers=headers,
            verify=False,
        )
        r.raise_for_status()

        try:
            response = r.json()
            log_probs = []
            decoded_tokens = []
            for log_prob in response["prompt_logprobs"]:
                if log_prob:
                    token_key = next(iter(log_prob))
                    token_info = log_prob[token_key]
                    log_probs.append(token_info["logprob"])
                    decoded_tokens.append(token_info["decoded_token"])
                else:
                    log_probs.append(0.0)
                    decoded_tokens.append("")

            log_probs = log_probs[len(prompt_encoded) : len(prompt_completion_encoded)]
            decoded_tokens = decoded_tokens[len(prompt_encoded) : len(prompt_completion_encoded)]
            reconstructed_completion = "".join(decoded_tokens)
            if self.tokenizer.eos_token in reconstructed_completion:
                reconstructed_completion = reconstructed_completion[: -len(self.tokenizer.eos_token)]
            assert (
                reconstructed_completion == completion
            ), f"Tokens do not match completion: {reconstructed_completion} != {completion}"
        except Exception as e:
            raise RuntimeError(f"Generation API wrong response: {r.text}", e)

        logger.debug(f"Log likelihood calculation took {time.time() - time_t0:.2f} seconds")
        logger.debug(f"Tokens per second: {len(log_probs) / (time.time() - time_t0):.2f}")

        return log_probs

    def get_log_probs(self, prompt: str | Prompt, output: str | LLMOutput) -> list[float]:
        """
        Calculate the log probabilities of the given output based on the provided prompt.

        Args:
            prompt (Union[str, Prompt]): The input prompt, which can be either a string or a Prompt object.
            output (Union[str, LLMOutput]): The output to evaluate, which can be either a string or an LLMOutput object.

        Returns:
            list[float]: A list of log probabilities corresponding to the given output.

        Raises:
            ValueError: If the input types are not valid.
        """
        if isinstance(prompt, str) and isinstance(output, str):
            return self.get_log_probs_complete(prompt=prompt, output=output)
        elif isinstance(prompt, Prompt) and isinstance(output, LLMOutput):
            return self.get_log_probs_chat_complete(prompt=prompt, output=output)
        else:
            raise ValueError("Invalid input types")

    def count_tokens(self, messages: list[dict] | str) -> int:
        """
        Count the number of tokens in the given messages.

        This method loads the tokenizer and then counts the number of tokens
        in the provided messages. The messages can be either a string or a list
        of dictionaries.

        Args:
            messages (Union[list[dict], str]): The messages to count tokens for. It can
                               be a single string or a list of dictionaries.

        Returns:
            int: The number of tokens in the provided messages.
        """
        self.load_tokenizer()
        if isinstance(messages, str):
            return len(self.tokenizer(messages).input_ids)
        else:
            return len(self.tokenizer.apply_chat_template(messages))
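
A minimal usage sketch of streaming generation (not taken from the library docs): it assumes the TrainableLLM constructor accepts base_url, model_name, parameters and stream keyword arguments (matching the attributes used in the source above), that Prompt is importable from tapeagents.core and takes a messages list, and that iterating the returned stream yields LLMEvent objects whose unset fields are None.

from tapeagents.core import Prompt          # assumed import path
from tapeagents.llms import TrainableLLM

llm = TrainableLLM(
    base_url="http://localhost:8000",       # OpenAI-compatible endpoint, e.g. a vLLM server
    model_name="my-finetuned-model",        # hypothetical model name
    stream=True,
    parameters={"temperature": 0.2, "max_tokens": 256},
)

prompt = Prompt(messages=[{"role": "user", "content": "Explain beam search in one paragraph."}])

chunks = []
for event in llm.generate(prompt):
    if event.chunk:                         # streamed delta text
        chunks.append(event.chunk)
    if event.output:                        # final LLMOutput, emitted last
        print(event.output.content)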

count_tokens(messages)

Count the number of tokens in the given messages.

This method loads the tokenizer and then counts the number of tokens in the provided messages. The messages can be either a string or a list of dictionaries.

Parameters:

  • messages (Union[list[dict], str]) –

    The messages to count tokens for. It can be a single string or a list of dictionaries.

Returns:

  • int –

    The number of tokens in the provided messages.

Source code in tapeagents/llms.py
def count_tokens(self, messages: list[dict] | str) -> int:
    """
    Count the number of tokens in the given messages.

    This method loads the tokenizer and then counts the number of tokens
    in the provided messages. The messages can be either a string or a list
    of dictionaries.

    Args:
        messages (Union[list[dict], str]): The messages to count tokens for. It can
                           be a single string or a list of dictionaries.

    Returns:
        int: The number of tokens in the provided messages.
    """
    self.load_tokenizer()
    if isinstance(messages, str):
        return len(self.tokenizer(messages).input_ids)
    else:
        return len(self.tokenizer.apply_chat_template(messages))
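
A small usage sketch of both branches, continuing with the hypothetical llm instance constructed in the streaming example above:

n_str = llm.count_tokens("The quick brown fox jumps over the lazy dog.")
n_chat = llm.count_tokens([
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "How many tokens is this?"},
])
# The chat branch applies the tokenizer's chat template first, so the count
# includes role markers and special tokens, not just the raw message text.
print(n_str, n_chat)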

get_log_probs(prompt, output)

Calculate the log probabilities of the given output based on the provided prompt.

Parameters:

  • prompt (Union[str, Prompt]) –

    The input prompt, which can be either a string or a Prompt object.

  • output (Union[str, LLMOutput]) –

    The output to evaluate, which can be either a string or an LLMOutput object.

Returns:

  • list[float] –

    A list of log probabilities corresponding to the given output.

Raises:

  • ValueError

    If the input types are not valid.

Source code in tapeagents/llms.py
def get_log_probs(self, prompt: str | Prompt, output: str | LLMOutput) -> list[float]:
    """
    Calculate the log probabilities of the given output based on the provided prompt.

    Args:
        prompt (Union[str, Prompt]): The input prompt, which can be either a string or a Prompt object.
        output (Union[str, LLMOutput]): The output to evaluate, which can be either a string or an LLMOutput object.

    Returns:
        list[float]: A list of log probabilities corresponding to the given output.

    Raises:
        ValueError: If the input types are not valid.
    """
    if isinstance(prompt, str) and isinstance(output, str):
        return self.get_log_probs_complete(prompt=prompt, output=output)
    elif isinstance(prompt, Prompt) and isinstance(output, LLMOutput):
        return self.get_log_probs_chat_complete(prompt=prompt, output=output)
    else:
        raise ValueError("Invalid input types")
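
A dispatch sketch, continuing with the hypothetical llm instance from above; the Prompt and LLMOutput constructors follow the fields used in the source and are otherwise assumed:

from tapeagents.core import LLMOutput, Prompt   # assumed import path

# Plain strings go through the /v1/completions path ...
lp_text = llm.get_log_probs("The capital of France is", " Paris")

# ... while a Prompt/LLMOutput pair goes through the /v1/chat/completions path.
lp_chat = llm.get_log_probs(
    Prompt(messages=[{"role": "user", "content": "Name the capital of France."}]),
    LLMOutput(content="Paris"),
)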

get_log_probs_chat_complete(prompt, output)

Calculate the log probabilities of the tokens in the completion generated by the language model.

This function sends a request to the language model API to generate completions and calculate log probabilities. The function uses the tokenizer to encode the prompt and completion texts. The log probabilities are extracted from the API response and validated against the original completion.

Parameters:

  • prompt (Prompt) –

    The prompt containing the messages to be sent to the language model.

  • output (LLMOutput) –

    The output from the language model containing the generated completion.

Returns:

  • list[float] –

    A list of log probabilities for each token in the generated completion.

Raises:

  • RuntimeError

    If the response from the generation API is incorrect or cannot be parsed.

Source code in tapeagents/llms.py
def get_log_probs_chat_complete(self, prompt: Prompt, output: LLMOutput) -> list[float]:
    """
    Calculate the log probabilities of the tokens in the completion generated by the language model.

    This function sends a request to the language model API to generate completions and calculate log probabilities.
    The function uses the tokenizer to encode the prompt and completion texts.
    The log probabilities are extracted from the API response and validated against the original completion.

    Args:
        prompt (Prompt): The prompt containing the messages to be sent to the language model.
        output (LLMOutput): The output from the language model containing the generated completion.

    Returns:
        list[float]: A list of log probabilities for each token in the generated completion.

    Raises:
        RuntimeError: If the response from the generation API is incorrect or cannot be parsed.
    """
    headers = {"Content-Type": "application/json"}
    if self.api_token:
        headers |= {"Authorization": f"Bearer {self.api_token}"}

    time_t0 = time.time()
    prompt_text = self.tokenizer.apply_chat_template(prompt.messages, tokenize=False)
    completion = output.content or ""
    messages = prompt.messages + [output.model_dump()]
    prompt_text = self.tokenizer.apply_chat_template(prompt.messages, tokenize=False, add_generation_prompt=True)
    prompt_completion_text = self.tokenizer.apply_chat_template(messages, tokenize=False)
    if self.tokenizer.bos_token and prompt_text.startswith(self.tokenizer.bos_token):
        prompt_text = prompt_text[len(self.tokenizer.bos_token) :]
        prompt_completion_text = prompt_completion_text[len(self.tokenizer.bos_token) :]

    prompt_encoded = self.tokenizer.encode(prompt_text, add_special_tokens=True)
    prompt_completion_encoded = self.tokenizer.encode(prompt_completion_text, add_special_tokens=True)

    generation_args = {
        "model": self.model_name,
        "messages": messages,
        "temperature": 0.0,
        "max_tokens": 1,
        "logprobs": 1,
        "echo": True,
        "include_stop_str_in_output": True,  # self.include_stop_str_in_output,
        "skip_special_tokens": False,
        "n": 1,  # number of completions to generate
        "stream": False,  # return a single completion and not a stream of lines
    }
    r = requests.post(
        url=f"{self.base_url}/v1/chat/completions",
        json=generation_args,
        headers=headers,
        verify=False,
    )
    r.raise_for_status()

    try:
        response = r.json()
        log_probs = []
        decoded_tokens = []
        for log_prob in response["prompt_logprobs"]:
            if log_prob:
                token_key = next(iter(log_prob))
                token_info = log_prob[token_key]
                log_probs.append(token_info["logprob"])
                decoded_tokens.append(token_info["decoded_token"])
            else:
                log_probs.append(0.0)
                decoded_tokens.append("")

        log_probs = log_probs[len(prompt_encoded) : len(prompt_completion_encoded)]
        decoded_tokens = decoded_tokens[len(prompt_encoded) : len(prompt_completion_encoded)]
        reconstructed_completion = "".join(decoded_tokens)
        if self.tokenizer.eos_token in reconstructed_completion:
            reconstructed_completion = reconstructed_completion[: -len(self.tokenizer.eos_token)]
        assert (
            reconstructed_completion == completion
        ), f"Tokens do not match completion: {reconstructed_completion} != {completion}"
    except Exception as e:
        raise RuntimeError(f"Generation API wrong response: {r.text}", e)

    logger.debug(f"Log likelihood calculation took {time.time() - time_t0:.2f} seconds")
    logger.debug(f"Tokens per second: {len(log_probs) / (time.time() - time_t0):.2f}")

    return log_probs
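
The parsing loop above expects a vLLM-style prompt_logprobs field: a list with one entry per prompt token, where each entry is either empty (no logprob available, e.g. for the first token) or a single-key mapping from token id to logprob info. Below is a standalone sketch of that extraction over an illustrative payload; the exact field shape is assumed here and depends on the serving backend:

# Illustrative payload mirroring the structure parsed above.
prompt_logprobs = [
    None,                                                    # no logprob for the first token
    {"1234": {"logprob": -0.12, "decoded_token": "Hello"}},
    {"5678": {"logprob": -1.05, "decoded_token": " world"}},
]

log_probs, decoded = [], []
for entry in prompt_logprobs:
    if entry:
        token_info = entry[next(iter(entry))]                # single-key dict per position
        log_probs.append(token_info["logprob"])
        decoded.append(token_info["decoded_token"])
    else:
        log_probs.append(0.0)
        decoded.append("")

print(sum(log_probs), "".join(decoded))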

get_log_probs_complete(prompt, output)

Get the log probabilities of the tokens in the output given the prompt.

This method sends a request to the language model API to generate the log probabilities for the tokens in the provided output, given the prompt. It uses the tokenizer to encode the prompt and output, and extracts the log probabilities from the API response.

Parameters:

  • prompt (str) –

    The input prompt text.

  • output (str) –

    The output text for which log probabilities are to be calculated.

Returns:

  • list[float] –

    A list of log probabilities for each token in the output.

Raises:

  • RuntimeError

    If the API response is not as expected or if there is a mismatch between the tokens in the response and the provided output.

Source code in tapeagents/llms.py
def get_log_probs_complete(self, prompt: str, output: str) -> list[float]:
    """
    Get the log probabilities of the tokens in the output given the prompt.

    This method sends a request to the language model API to generate the log probabilities
    for the tokens in the provided output, given the prompt. It uses the tokenizer to encode
    the prompt and output, and extracts the log probabilities from the API response.

    Args:
        prompt (str): The input prompt text.
        output (str): The output text for which log probabilities are to be calculated.

    Returns:
        list[float]: A list of log probabilities for each token in the output.

    Raises:
        RuntimeError: If the API response is not as expected or if there is a mismatch
                      between the tokens in the response and the provided output.
    """
    if not self.tokenizer:
        self.load_tokenizer()

    headers = {"Content-Type": "application/json"}
    if self.api_token:
        headers |= {"Authorization": f"Bearer {self.api_token}"}

    if self.tokenizer.bos_token and prompt.startswith(self.tokenizer.bos_token):
        prompt = prompt[len(self.tokenizer.bos_token) :]

    prompt_text = prompt + output
    generation_args = {
        "model": self.model_name,
        "prompt": prompt_text,
        "temperature": 0.0,
        "max_tokens": 0,
        "logprobs": 1,
        "echo": True,
        "include_stop_str_in_output": True,  # self.include_stop_str_in_output,
        "skip_special_tokens": False,
        "n": 1,  # number of completions to generate
        "stream": False,  # return a single completion and not a stream of lines
    }
    url = f"{self.base_url}/v1/completions"
    logger.debug(f"POST request to {url}")
    r = requests.post(url, json=generation_args, headers=headers, verify=False)
    r.raise_for_status()  # raise exception if status code is not in the 200s
    try:
        response = r.json()
        log_probs = response["choices"][0]["logprobs"]["token_logprobs"]
        prompt_encoded = self.tokenizer.encode(prompt, add_special_tokens=True)
        prompt_completion_encoded = self.tokenizer.encode(prompt + output, add_special_tokens=True)
        log_probs = log_probs[len(prompt_encoded) : len(prompt_completion_encoded)]
        tokens = response["choices"][0]["logprobs"]["tokens"]
        tokens = tokens[len(prompt_encoded) : len(prompt_completion_encoded)]
        assert "".join(tokens) == output, f"Tokens do not match completion: {''.join(tokens)} != {output}"
    except Exception as e:
        raise RuntimeError(f"Generation API wrong response: {r.text}", e)
    return log_probs
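
The request relies on the "echo" trick of OpenAI-compatible completion servers: with max_tokens=0, echo=True and logprobs=1 the server scores the supplied text instead of generating new tokens, and the log probabilities of the output are recovered by slicing off the prompt's token count. A usage sketch, continuing with the hypothetical llm instance from above:

# Score " Paris" as a continuation of the prompt string.
token_log_probs = llm.get_log_probs_complete(
    prompt="The capital of France is",
    output=" Paris",
)
# The sum of per-token log probabilities is log P(output | prompt) under the model.
print(sum(token_log_probs))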

load_tokenizer()

Loads the tokenizer for the model.

If the tokenizer is not already loaded, this method will import the transformers library and load the tokenizer using the model name or tokenizer name. If _MOCK_TOKENIZER is set, it will use that instead.

Raises:

  • ValueError

    If neither self.tokenizer_name nor self.model_name is provided and _MOCK_TOKENIZER is not set.

Source code in tapeagents/llms.py
def load_tokenizer(self):
    """
    Loads the tokenizer for the model.

    If the tokenizer is not already loaded, this method will import the
    `transformers` library and load the tokenizer using the model name or
    tokenizer name. If `_MOCK_TOKENIZER` is set, it will use that instead.

    Raises:
        ValueError: If neither `self.tokenizer_name` nor `self.model_name`
                    is provided and `_MOCK_TOKENIZER` is not set.
    """
    if self.tokenizer is None:
        import transformers

        name = _MOCK_TOKENIZER if _MOCK_TOKENIZER else (self.tokenizer_name or self.model_name)
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(name)

make_training_text(prompt, output)

Generates training text from a given prompt and LLM output.

This method loads the tokenizer and uses it to create training text suitable for training a language model.

Parameters:

  • prompt (Prompt) –

    The input prompt to generate training text from.

  • output (LLMOutput) –

    The output from the language model to be used in training.

Returns:

  • TrainingText –

    The generated training text.

Source code in tapeagents/llms.py
def make_training_text(self, prompt: Prompt, output: LLMOutput) -> TrainingText:
    """
    Generates training text from a given prompt and LLM output.

    This method loads the tokenizer and uses it to create training text
    suitable for training a language model.

    Args:
        prompt (Prompt): The input prompt to generate training text from.
        output (LLMOutput): The output from the language model to be used in training.

    Returns:
        TrainingText: The generated training text.
    """
    self.load_tokenizer()
    return trainable_llm_make_training_text(prompt, output, self.tokenizer)
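
A usage sketch, continuing with the hypothetical llm instance from above and assuming Prompt and LLMOutput are constructed as shown earlier:

sample = llm.make_training_text(
    Prompt(messages=[{"role": "user", "content": "Say hi."}]),
    LLMOutput(content="Hi!"),
)
# TrainingText exposes the full formatted text and the length of the predicted part.
print(sample.text[-sample.n_predicted:])   # the assistant portion the model is trained on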

closest_prompt(prompt_key, known_prompts)

Finds the closest matching prompt from a list of known prompts based on a Levenshtein similarity ratio.

Parameters:

  • prompt_key (str) –

    The prompt to compare against the known prompts.

  • known_prompts (list[str]) –

    A list of known prompts to compare with the prompt_key.

Returns:

  • tuple[str, float] –

    A tuple containing the closest matching prompt and its similarity score. If no prompts are found, returns an empty string and a score of 0.0.

Source code in tapeagents/llms.py
def closest_prompt(prompt_key: str, known_prompts: list[str]) -> tuple[str, float]:
    """
    Finds the closest matching prompt from a list of known prompts based on a Levenshtein similarity ratio.

    Args:
        prompt_key (str): The prompt to compare against the known prompts.
        known_prompts (list[str]): A list of known prompts to compare with the prompt_key.

    Returns:
        tuple[str, float]: A tuple containing the closest matching prompt and its similarity score.
                           If no prompts are found, returns an empty string and a score of 0.0.
    """
    ratios = [(k, ratio(prompt_key, k, score_cutoff=0.5)) for k in known_prompts]
    if not len(ratios):
        return "", 0.0
    ratios = sorted(ratios, key=lambda x: x[1], reverse=True)
    closest, score = sorted(ratios, key=lambda x: x[1], reverse=True)[0]
    return closest, score
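
A short usage example; the prompts are hypothetical and the exact score depends on the underlying ratio implementation:

from tapeagents.llms import closest_prompt

closest, score = closest_prompt(
    "What is 2 + 2?",
    ["What is 2 + 3?", "Summarize this article.", "Translate to French: hello"],
)
print(closest, score)           # expected: "What is 2 + 3?" with a high similarity score

closest, score = closest_prompt("What is 2 + 2?", [])
print(repr(closest), score)     # '' 0.0 when there are no known prompts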

trainable_llm_make_training_text(prompt, output, tokenizer)

Generates training text for LLM fine-tuning by combining prompt and output using tokenizer's chat template.

Parameters:

  • prompt (Prompt) –

    The input prompt containing conversation messages.

  • output (LLMOutput) –

    The model's output/response.

  • tokenizer (PreTrainedTokenizer) –

    The tokenizer used to format the conversation.

Returns:

  • TrainingText –

    A dataclass containing:

    • text (str): The formatted conversation text
    • n_predicted (int): Length of the output text portion

Note:

  • Uses the tokenizer's chat template to format conversations
  • Removes the BOS token if present at the beginning of the text
Source code in tapeagents/llms.py
def trainable_llm_make_training_text(prompt: Prompt, output: LLMOutput, tokenizer) -> TrainingText:
    """
    Generates training text for LLM fine-tuning by combining prompt and output using tokenizer's chat template.

    Args:
        prompt (Prompt): The input prompt containing conversation messages.
        output (LLMOutput): The model's output/response.
        tokenizer (PreTrainedTokenizer): The tokenizer used to format the conversation.

    Returns:
        TrainingText: A dataclass containing:

            - text (str): The formatted conversation text
            - n_predicted (int): Length of the output text portion

    Note:
        - Uses tokenizer's chat template to format conversations
        - Removes BOS token if present in the beginning of the text
    """
    prompt_text = tokenizer.apply_chat_template(
        conversation=prompt.messages, tokenize=False, add_generation_prompt=True
    )
    text = tokenizer.apply_chat_template(
        prompt.messages + [{"role": "assistant", "content": output.content}], tokenize=False
    )
    output_text = text[len(prompt_text) :]

    if tokenizer.bos_token and text.startswith(tokenizer.bos_token):
        text = text[len(tokenizer.bos_token) :]

    return TrainingText(text=text, n_predicted=len(output_text))
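
A usage sketch of the module-level helper. The tokenizer name is illustrative (any Hugging Face tokenizer with a chat template works), and the Prompt/LLMOutput import path is assumed:

from transformers import AutoTokenizer

from tapeagents.core import LLMOutput, Prompt             # assumed import path
from tapeagents.llms import trainable_llm_make_training_text

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # hypothetical choice
prompt = Prompt(messages=[{"role": "user", "content": "Give me a haiku about tapes."}])
output = LLMOutput(content="Reels of memory\nspin quietly through the night\nreplaying our words")

sample = trainable_llm_make_training_text(prompt, output, tokenizer)
print(sample.text)          # full chat-formatted conversation, with the BOS token stripped
print(sample.n_predicted)   # character length of the assistant portion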