Skip to content

Batch

Batch processing of tapes.

Functions:

batch_main_loop(agent, tapes, environments, n_workers=_DEFAULT_N_WORKERS, strict=False, max_loops=-1)

Continue tapes in parallel using an agent.

Source code in tapeagents/batch.py (lines 24–54):
def batch_main_loop(
    agent: Agent[TapeType],
    tapes: list[TapeType],
    environments: Environment | list[Environment],
    n_workers: int = _DEFAULT_N_WORKERS,
    strict: bool = False,
    max_loops: int = -1,
) -> Generator[TapeType, None, None]:
    """Continue tapes in parallel using an agent.

    Each tape is paired with an environment and run through ``main_loop``; the
    final tape of every run is yielded in the order the processor returned by
    ``choose_processor`` produces results.

    Args:
        agent: The agent used to continue every tape.
        tapes: The tapes to continue.
        environments: Either a single environment, used for every tape, or a
            list containing exactly one environment per tape.
            NOTE(review): a single environment is the same instance for all
            tapes; presumably it must tolerate concurrent use when
            ``n_workers`` allows parallelism — confirm.
        n_workers: Forwarded to ``choose_processor``.
        strict: If True, an exception raised while continuing a tape is
            re-raised to the caller instead of being recorded on the tape.
        max_loops: Forwarded to ``main_loop``.

    Yields:
        For each input tape, either its final continued tape (with
        ``metadata.parent_id`` set to the input tape's id) or, when a failure
        occurs outside strict/debug mode, a copy of the input tape whose fresh
        ``TapeMetadata`` carries the input tape's id as ``parent_id`` and the
        formatted traceback as ``error``.

    Raises:
        ValueError: If ``environments`` is a list whose length differs from
            ``len(tapes)``. Because this is a generator, the error surfaces on
            the first iteration.
        Exception: Whatever a worker raised, when ``strict`` is True or debug
            mode is enabled.
    """
    if not isinstance(environments, list):
        # Repeat the single environment so it can be zipped with the tapes.
        environments = [environments] * len(tapes)
    elif len(environments) != len(tapes):
        # Without this check, zip() below would silently drop unmatched tapes.
        raise ValueError(
            f"Expected one environment per tape, got {len(environments)} environments for {len(tapes)} tapes"
        )

    def worker_func(task: tuple[TapeType, Environment]) -> TapeType | Exception:
        # Failures are returned (not raised) from the worker; the consuming loop
        # below decides whether to re-raise them.
        start_tape, env = task
        try:
            result = main_loop(agent, start_tape, env, max_loops=max_loops).get_final_tape()
        except Exception as e:
            if is_debug_mode() or strict:
                return e
            # Record the failure on a copy of the input tape so the rest of the batch continues.
            return start_tape.model_copy(
                update=dict(metadata=TapeMetadata(parent_id=start_tape.metadata.id, error=traceback.format_exc()))
            )
        result.metadata.parent_id = start_tape.metadata.id
        return result

    processor = choose_processor(n_workers=n_workers)
    for outcome in processor(zip(tapes, environments), worker_func):
        if isinstance(outcome, Tape):
            yield outcome
        else:
            # An exception object handed back by worker_func in strict/debug mode.
            raise outcome