
Data

Functions:

  • mask_labels

    This function creates labels from a sequence of input ids by masking the tokens that do not overlap the character spans designated for prediction.

  • validate_spans

    Make sure the spans are valid, don't overlap, and are in order.

mask_labels(input_ids, offset_mapping, predicted_spans, masked_token_id=MASKED_TOKEN_ID)

This function creates labels from a sequence of input ids by masking the tokens that do not have any overlap with the character spans designated for prediction. A token whose character offsets overlap a predicted span keeps its input id as its label; every other position is set to masked_token_id. The labels can then be used to train a model to predict only the unmasked tokens.
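A minimal sketch of the masking rule on hand-made data. The token ids and offset mapping are hypothetical (no real tokenizer is involved), and the mask id -100 is an assumption: it is the default ignore_index of PyTorch's CrossEntropyLoss, but the value of MASKED_TOKEN_ID is not shown on this page.

```python
MASK = -100  # assumed mask id for this sketch; the library's MASKED_TOKEN_ID may differ

text = "Q: hi A: hello"
input_ids = [101, 102, 103, 104]  # hypothetical ids, one per token below
offset_mapping = [(0, 2), (3, 5), (6, 8), (9, 14)]  # "Q:", "hi", "A:", "hello"
predicted_spans = [(9, 14)]  # character span of "hello"


def overlaps(token: tuple[int, int], span: tuple[int, int]) -> bool:
    # Half-open intervals [a, b) and [c, d) intersect iff a < d and c < b.
    (a, b), (c, d) = token, span
    return a < d and c < b


# Keep the id of every token that overlaps some predicted span; mask the rest.
labels = [
    tok_id if any(overlaps(off, span) for span in predicted_spans) else MASK
    for tok_id, off in zip(input_ids, offset_mapping)
]
print(labels)  # [-100, -100, -100, 104]
```

Because the test is overlap rather than containment, a span covering only part of a token, such as "ll" at (11, 13), still keeps the whole "hello" token.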

The function also returns a list of midpoints for splitting the labels into a source and a target. The source is the part of the labels that is used to predict the target. There is one midpoint for each predicted span that overlaps at least one token; a span that overlaps no token contributes no midpoint, so the list can be shorter than predicted_spans. Each midpoint is the index of the first token that overlaps with the corresponding span.
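The sketch below computes midpoints the same way (first overlapping token per span) and then splits a sequence at each one. Two assumptions: a whitespace split stands in for a real tokenizer's offset mapping, and the split is taken to be a plain slice at the midpoint, shown on the words rather than on label ids so the result is readable. The helper names are hypothetical.

```python
import re
from typing import Optional


def whitespace_offsets(text: str) -> list[tuple[int, int]]:
    # Hypothetical stand-in for a tokenizer's offset mapping: one token per word.
    return [(m.start(), m.end()) for m in re.finditer(r"\S+", text)]


def first_overlapping_token(offsets, span) -> Optional[int]:
    # Index of the first token whose [begin, end) intersects the span, else None.
    span_begin, span_end = span
    for i, (begin, end) in enumerate(offsets):
        if begin < span_end and span_begin < end:
            return i
    return None


def char_span(text: str, piece: str) -> tuple[int, int]:
    # Half-open character span of the first occurrence of `piece`.
    start = text.index(piece)
    return start, start + len(piece)


text = "user: hi bot: hello user: ok bot: bye"
words = text.split()
offsets = whitespace_offsets(text)
spans = [char_span(text, "hello"), char_span(text, "bye")]

# Spans that overlap no token are skipped, so there can be fewer midpoints than spans.
candidates = [first_overlapping_token(offsets, s) for s in spans]
midpoints = [i for i in candidates if i is not None]
print(midpoints)  # [3, 7]

for m in midpoints:
    # Assumed split: everything before the midpoint is the source.
    source, target = words[:m], words[m:]
    print(source[-1], "->", target[0])
```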

Parameters:

  • input_ids (Sequence[int]) –

    A sequence of token ids.

  • offset_mapping (Iterable[tuple[int, int]]) –

    The offset mapping returned by the tokenizer.

  • predicted_spans (Iterable[Iterable[int]]) –

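    The character spans that are designated for prediction. The spans are given as a sequence of two-element sequences, where the first element is the beginning of the span (inclusive) and the second element is the end of the span (exclusive).

  • masked_token_id (int, default: MASKED_TOKEN_ID) –

    The label written at every position whose token does not overlap any predicted span.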

Returns:

  • tuple[list[int], list[int]]

    A tuple of masked labels and the corresponding midpoints for splitting the labels into a source and a target.

Source code in tapeagents/finetune/data.py
def mask_labels(
    input_ids: Sequence[int],
    offset_mapping: Iterable[tuple[int, int]],
    predicted_spans: Iterable[Iterable[int]],
    masked_token_id: int = MASKED_TOKEN_ID,
) -> tuple[list[int], list[int]]:
    """
    This function creates labels from a sequence of input ids by masking
    the tokens that do not have any overlap with the character spans that
    are designated for prediction. The labels can then be used to train
    a model to predict everything except the masked tokens.

    The function also returns a list of midpoints for splitting the
    labels into a source and a target. The source is the part of the
    labels that is used to predict the target. There is one midpoint
    for each span that is designated for prediction. Each midpoint is
    the index of the first token that overlaps with the corresponding
    span.

    Args:
        input_ids (Sequence[int]): A sequence of token ids.
        offset_mapping (Iterable[tuple[int, int]]): The offset mapping
            returned by the tokenizer.
        predicted_spans (Iterable[Iterable[int]]): The character spans
            that are designated for prediction. The spans are given as
            a sequence of two-element sequences, where the first element
            is the beginning of the span (inclusive) and the second
            element is the end of the span (not inclusive).

    Returns:
        tuple[list[int], list[int]]: A tuple of masked labels and
            corresponding midpoints for splitting the labels into
            a source and a target.
    """
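    labels = [masked_token_id] * len(input_ids)
    midpoints = []
    # offset_mapping is typed as Iterable, so materialize it once: it is
    # traversed again for every span below.
    offset_mapping = list(offset_mapping)
    # TODO: Make this O(n_tokens) instead of O(n_tokens * n_spans)
    for span_begin, span_end in predicted_spans:
        midpoint_found = False
        for i, (offset_begin, offset_end) in enumerate(offset_mapping):
            # The half-open intervals [offset_begin, offset_end) and
            # [span_begin, span_end) overlap iff each starts before the other ends.
            if offset_begin < span_end and span_begin < offset_end:
                if not midpoint_found:
                    # The first overlapping token marks this span's midpoint.
                    midpoints.append(i)
                    midpoint_found = True
                labels[i] = input_ids[i]
    return labels, midpoints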
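The TODO above asks for a scan that avoids checking every token against every span. One possible approach, sketched here and not taken from the library, is a single pass over the tokens with a pointer into the spans. It assumes the spans are sorted and non-overlapping (what validate_spans checks) and that token begin offsets are non-decreasing apart from (0, 0) special tokens. The function name and the -100 default are hypothetical. Its cost is O(n_tokens + n_spans + number of overlapping token/span pairs).

```python
from typing import Optional, Sequence

MASK = -100  # assumed placeholder; the library uses its own MASKED_TOKEN_ID


def mask_labels_linear(
    input_ids: Sequence[int],
    offset_mapping: Sequence[tuple[int, int]],
    predicted_spans: Sequence[tuple[int, int]],
    masked_token_id: int = MASK,
) -> tuple[list[int], list[int]]:
    """Hypothetical single-pass variant of mask_labels (sketch only)."""
    labels = [masked_token_id] * len(input_ids)
    first_token: list[Optional[int]] = [None] * len(predicted_spans)
    j = 0  # first span that can still overlap the current or any later token
    for i, (begin, end) in enumerate(offset_mapping):
        if begin == 0 and end == 0:
            continue  # (0, 0) never passes the overlap test for spans with begin >= 0
        # Spans ending at or before this token's start cannot overlap it or later tokens.
        while j < len(predicted_spans) and predicted_spans[j][1] <= begin:
            j += 1
        k = j
        # Spans are sorted, so stop at the first one starting at or after this token's end.
        while k < len(predicted_spans) and predicted_spans[k][0] < end:
            span_begin, span_end = predicted_spans[k]
            if begin < span_end and span_begin < end:  # same half-open overlap test
                labels[i] = input_ids[i]
                if first_token[k] is None:
                    first_token[k] = i  # first overlapping token is the midpoint
            k += 1
    # As in mask_labels, spans that overlap no token contribute no midpoint.
    midpoints = [i for i in first_token if i is not None]
    return labels, midpoints


labels, midpoints = mask_labels_linear(
    [9, 10, 11, 12],
    [(0, 0), (0, 3), (4, 7), (0, 0)],  # (0, 0) entries stand in for special tokens
    [(3, 4), (4, 7)],  # the first span covers only the gap between two tokens
)
print(labels, midpoints)  # [-100, -100, 11, -100] [2]
```

A token that straddles two adjacent spans is kept once and can be the midpoint of the second span, which matches the nested-loop version.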

validate_spans(text, predicted_spans)

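Make sure the spans are valid, don't overlap, and are in order. Each span must satisfy 0 <= start <= end <= len(text), and each span must start at or after the end of the previous one; otherwise a ValueError is raised.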

Source code in tapeagents/finetune/data.py
def validate_spans(text: str, predicted_spans: list[tuple[int, int]]) -> None:
    """Make sure the spans are valid, don't overlap, and are in order."""
    for start, end in predicted_spans:
        if start < 0 or end > len(text):
            raise ValueError(f"Span {start}:{end} is out of bounds for text {text!r}")
        if start > end:
            raise ValueError(f"Span {start}:{end} is invalid")
    for (start1, end1), (start2, end2) in zip(predicted_spans, predicted_spans[1:]):
        # Make sure the second span starts after the first one ends.
        if start2 < end1:
            raise ValueError(
                f"Spans {start1}:{end1} ({text[start1:end1]!r}) and {start2}:{end2} ({text[start2:end2]!r}) overlap"
            )
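validate_spans enforces both ordering and non-overlap with a single comparison per consecutive pair. The hypothetical helper below isolates that comparison to show what it accepts for half-open spans; it mirrors the pairwise loop above and is not part of the library.

```python
def consecutive_overlaps(
    spans: list[tuple[int, int]],
) -> list[tuple[tuple[int, int], tuple[int, int]]]:
    # Same pairwise test as validate_spans: a span conflicts with its
    # predecessor if it starts before the predecessor ends.
    return [(s1, s2) for s1, s2 in zip(spans, spans[1:]) if s2[0] < s1[1]]


print(consecutive_overlaps([(0, 3), (3, 5)]))  # [] : touching half-open spans are allowed
print(consecutive_overlaps([(0, 4), (3, 5)]))  # [((0, 4), (3, 5))] : real overlap
print(consecutive_overlaps([(5, 8), (0, 3)]))  # [((5, 8), (0, 3))] : out of order
```

Because the same test covers both cases, spans passed out of order are reported through the "overlap" error message rather than a separate ordering error.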