Annotation API

Package Exports

Annotation module for human evaluation of proactive agent proposals.

Annotation

Bases: BaseModel

A single annotation record.

Source code in pare/annotation/models.py
class Annotation(BaseModel):
    """A single annotation record."""

    annotation_id: str
    sample_id: str
    annotator_id: str
    human_decision: TernaryDecision
    gather_context_rationale: str | None = None
    timestamp: str

    @classmethod
    def create(
        cls,
        sample_id: str,
        annotator_id: str,
        human_decision: TernaryDecision,
        gather_context_rationale: str | None = None,
    ) -> Annotation:
        """Create a new annotation record.

        Args:
            sample_id: The sample being annotated.
            annotator_id: The annotator's anonymous ID.
            human_decision: The human's accept/reject/gather_context decision.
            gather_context_rationale: Free-text rationale when decision is gather_context.

        Returns:
            A new Annotation instance.
        """
        return cls(
            annotation_id=str(uuid4()),
            sample_id=sample_id,
            annotator_id=annotator_id,
            human_decision=human_decision,
            gather_context_rationale=gather_context_rationale,
            timestamp=datetime.now().isoformat(),
        )

    def to_csv_row(self) -> str:
        """Convert to CSV row string."""
        output = io.StringIO()
        writer = csv.writer(output)
        writer.writerow([
            self.annotation_id,
            self.sample_id,
            self.annotator_id,
            self.human_decision,
            self.gather_context_rationale or "",
            self.timestamp,
        ])
        return output.getvalue()

    @classmethod
    def csv_header(cls) -> str:
        """Get CSV header row."""
        return "annotation_id,sample_id,annotator_id,human_decision,gather_context_rationale,timestamp\n"

create(sample_id, annotator_id, human_decision, gather_context_rationale=None) classmethod

Create a new annotation record.

Parameters:

    sample_id (str, required): The sample being annotated.
    annotator_id (str, required): The annotator's anonymous ID.
    human_decision (TernaryDecision, required): The human's accept/reject/gather_context decision.
    gather_context_rationale (str | None, default None): Free-text rationale when decision is gather_context.

Returns:

    Annotation: A new Annotation instance.

Source code in pare/annotation/models.py
@classmethod
def create(
    cls,
    sample_id: str,
    annotator_id: str,
    human_decision: TernaryDecision,
    gather_context_rationale: str | None = None,
) -> Annotation:
    """Create a new annotation record.

    Args:
        sample_id: The sample being annotated.
        annotator_id: The annotator's anonymous ID.
        human_decision: The human's accept/reject/gather_context decision.
        gather_context_rationale: Free-text rationale when decision is gather_context.

    Returns:
        A new Annotation instance.
    """
    return cls(
        annotation_id=str(uuid4()),
        sample_id=sample_id,
        annotator_id=annotator_id,
        human_decision=human_decision,
        gather_context_rationale=gather_context_rationale,
        timestamp=datetime.now().isoformat(),
    )

csv_header() classmethod

Get CSV header row.

Source code in pare/annotation/models.py
@classmethod
def csv_header(cls) -> str:
    """Get CSV header row."""
    return "annotation_id,sample_id,annotator_id,human_decision,gather_context_rationale,timestamp\n"

to_csv_row()

Convert to CSV row string.

Source code in pare/annotation/models.py
def to_csv_row(self) -> str:
    """Convert to CSV row string."""
    output = io.StringIO()
    writer = csv.writer(output)
    writer.writerow([
        self.annotation_id,
        self.sample_id,
        self.annotator_id,
        self.human_decision,
        self.gather_context_rationale or "",
        self.timestamp,
    ])
    return output.getvalue()
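
For orientation, a minimal usage sketch of the CSV round trip. It assumes pydantic coerces the plain string "accept" to TernaryDecision (the enum's members are not shown on this page), and annotations.csv is a hypothetical output path:

from pathlib import Path

from pare.annotation.models import Annotation

# Assumption: pydantic coerces the plain string to TernaryDecision on validation.
ann = Annotation.create(
    sample_id="sample-001",
    annotator_id="annotator-42",
    human_decision="accept",
)

# Hypothetical output file: write the header once, then append rows.
out = Path("annotations.csv")
if not out.exists():
    out.write_text(Annotation.csv_header())
with out.open("a") as f:
    f.write(ann.to_csv_row())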

Sample

Bases: BaseModel

A sample for annotation (loaded from parquet).

Source code in pare/annotation/models.py
class Sample(BaseModel):
    """A sample for annotation (loaded from parquet)."""

    sample_id: str
    scenario_id: str
    run_number: int
    proactive_model_id: str
    user_model_id: str
    trace_file: str
    user_agent_decision: TernaryDecision
    agent_proposal: str
    meta_task_description: str
    llm_input: str
    final_decision: bool
    gather_context_delta: str | None = None
    tutorial: bool = False
    correct_decision: TernaryDecision | None = None
    explanation: str | None = None

    # Message types to strip from UI rendering
    _STRIPPED_MSG_TYPES: frozenset[str] = frozenset({
        "system_prompt",
        "available_tools",
        "current_app_state",
        "unknown",
    })

    @staticmethod
    def _extract_observation_content(content: str) -> str:
        r"""Strip Meta-ARE boilerplate from tool observation content.

        Extracts the text between ``***`` delimiters from format:
        ``[OUTPUT OF STEP N] Observation:\\n***\\n<content>\\n***``

        Args:
            content: Raw tool-response content string.

        Returns:
            Cleaned observation text, or original content if pattern doesn't match.
        """
        match = re.search(r"\*\*\*\n(.*?)\n\*\*\*", content, re.DOTALL)
        return match.group(1).strip() if match else content

    @staticmethod
    def _extract_notification_content(content: str) -> str:
        r"""Strip wrapper from environment notification content.

        Extracts text between ``***`` delimiters from format:
        ``Environment notifications updates:\\n***\\n<content>\\n***``

        Args:
            content: Raw environment notification content string.

        Returns:
            Cleaned notification text.
        """
        match = re.search(r"\*\*\*\n(.*?)\n\*\*\*", content, re.DOTALL)
        return match.group(1).strip() if match else content

    @staticmethod
    def _extract_proposal_message(content: str) -> str:
        """Extract the proposal message from [TASK] formatted content.

        Strips the ``[TASK]:``, ``Received at:``, ``Sender:``, ``Message:``,
        and ``Already read:`` framing added by Meta-ARE.

        Args:
            content: Raw proposal content with [TASK] framing.

        Returns:
            Clean proposal message text.
        """
        match = re.search(r"Message:\s*(.+?)(?:\nAlready read:|$)", content, re.DOTALL)
        if match:
            return match.group(1).strip()
        # Fallback: strip [TASK]: prefix at minimum
        if content.startswith("[TASK]:"):
            return content[len("[TASK]:") :].strip()
        return content

    @staticmethod
    def _filter_none_notifications(content: str) -> str:
        """Remove notification lines that contain only 'None' content.

        Filters out lines like ``[2025-11-18 09:00:01] None`` that represent
        environment ticks with no actual notification.

        Args:
            content: Notification content with potentially multiple lines.

        Returns:
            Filtered content with None-only lines removed. Empty string if all lines were None.
        """
        lines = content.strip().split("\n")
        filtered = [line for line in lines if not re.match(r"^\[[\d\-: ]+\]\s*None\s*$", line.strip())]
        return "\n".join(filtered)

    @staticmethod
    def _extract_tool_name(user_action_content: str) -> str:
        """Extract tool name from a user_action message content.

        Parses the ``Action: AppName__tool_name`` line from ReAct format.

        Args:
            user_action_content: The assistant role message content.

        Returns:
            Tool name string, or empty string if not found.
        """
        match = re.search(r"Action:\s*(\S+)", user_action_content)
        return match.group(1) if match else ""

    def to_api_response(self, progress_completed: int, progress_total: int) -> SampleResponse:
        """Convert to structured API response for the annotation UI.

        Parses llm_input JSON, filters out non-renderable message types,
        formats observations and notifications for human readability,
        and returns a typed SampleResponse.

        Args:
            progress_completed: Number of completed annotations.
            progress_total: Total number of annotations.

        Returns:
            SampleResponse with filtered, typed, formatted messages.
        """
        from pare.annotation.observation_formatter import ObservationFormatter, format_notification

        raw_messages: list[dict[str, object]] = json.loads(self.llm_input)

        messages: list[UIMessage] = []
        last_tool_name = ""
        for msg in raw_messages:
            msg_type_str = str(msg.get("msg_type", "unknown"))
            if msg_type_str in self._STRIPPED_MSG_TYPES:
                continue

            content = str(msg.get("content", ""))

            # Track tool name from user_action for formatting the next tool_observation
            if msg_type_str == "user_action":
                last_tool_name = self._extract_tool_name(content)

            # Format tool observations using ObservationFormatter
            if msg_type_str == "tool_observation":
                raw_obs = self._extract_observation_content(content)
                content = ObservationFormatter.format(last_tool_name, raw_obs)

            # Format proposals to strip [TASK] framing
            if msg_type_str == "proposal":
                content = self._extract_proposal_message(content)

            # Format notifications to strip hex IDs and filter None entries
            if msg_type_str == "environment_notification":
                raw_notif = self._extract_notification_content(content)
                raw_notif = self._filter_none_notifications(raw_notif)
                if not raw_notif.strip():
                    continue  # Skip entirely empty notification blocks
                content = format_notification(raw_notif)

            timestamp_val = msg.get("timestamp")
            timestamp = float(timestamp_val) if isinstance(timestamp_val, (int, float)) else None

            messages.append(
                UIMessage(
                    msg_type=MessageType(msg_type_str),
                    content=content,
                    timestamp=timestamp,
                )
            )

        return SampleResponse(
            sample_id=self.sample_id,
            scenario_context=self.meta_task_description if self.meta_task_description else None,
            messages=messages,
            progress_completed=progress_completed,
            progress_total=progress_total,
            tutorial=self.tutorial,
        )

to_api_response(progress_completed, progress_total)

Convert to structured API response for the annotation UI.

Parses llm_input JSON, filters out non-renderable message types, formats observations and notifications for human readability, and returns a typed SampleResponse.

Parameters:

    progress_completed (int, required): Number of completed annotations.
    progress_total (int, required): Total number of annotations.

Returns:

    SampleResponse: SampleResponse with filtered, typed, formatted messages.

Source code in pare/annotation/models.py
def to_api_response(self, progress_completed: int, progress_total: int) -> SampleResponse:
    """Convert to structured API response for the annotation UI.

    Parses llm_input JSON, filters out non-renderable message types,
    formats observations and notifications for human readability,
    and returns a typed SampleResponse.

    Args:
        progress_completed: Number of completed annotations.
        progress_total: Total number of annotations.

    Returns:
        SampleResponse with filtered, typed, formatted messages.
    """
    from pare.annotation.observation_formatter import ObservationFormatter, format_notification

    raw_messages: list[dict[str, object]] = json.loads(self.llm_input)

    messages: list[UIMessage] = []
    last_tool_name = ""
    for msg in raw_messages:
        msg_type_str = str(msg.get("msg_type", "unknown"))
        if msg_type_str in self._STRIPPED_MSG_TYPES:
            continue

        content = str(msg.get("content", ""))

        # Track tool name from user_action for formatting the next tool_observation
        if msg_type_str == "user_action":
            last_tool_name = self._extract_tool_name(content)

        # Format tool observations using ObservationFormatter
        if msg_type_str == "tool_observation":
            raw_obs = self._extract_observation_content(content)
            content = ObservationFormatter.format(last_tool_name, raw_obs)

        # Format proposals to strip [TASK] framing
        if msg_type_str == "proposal":
            content = self._extract_proposal_message(content)

        # Format notifications to strip hex IDs and filter None entries
        if msg_type_str == "environment_notification":
            raw_notif = self._extract_notification_content(content)
            raw_notif = self._filter_none_notifications(raw_notif)
            if not raw_notif.strip():
                continue  # Skip entirely empty notification blocks
            content = format_notification(raw_notif)

        timestamp_val = msg.get("timestamp")
        timestamp = float(timestamp_val) if isinstance(timestamp_val, (int, float)) else None

        messages.append(
            UIMessage(
                msg_type=MessageType(msg_type_str),
                content=content,
                timestamp=timestamp,
            )
        )

    return SampleResponse(
        sample_id=self.sample_id,
        scenario_context=self.meta_task_description if self.meta_task_description else None,
        messages=messages,
        progress_completed=progress_completed,
        progress_total=progress_total,
        tutorial=self.tutorial,
    )
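
A minimal sketch of the filtering pipeline, assuming string-to-enum coercion for user_agent_decision; all field values are hypothetical:

import json

from pare.annotation.models import Sample

# Two raw messages: the system_prompt is in _STRIPPED_MSG_TYPES and is dropped;
# the proposal has its [TASK] framing removed.
llm_input = json.dumps([
    {"msg_type": "system_prompt", "content": "You are a proactive assistant."},
    {
        "msg_type": "proposal",
        "content": "[TASK]: Received at: 2025-11-18 09:00\nSender: Agent\nMessage: Book the cab?\nAlready read: False",
        "timestamp": 1700000000.0,
    },
])

sample = Sample(
    sample_id="s-001", scenario_id="scn-1", run_number=1,
    proactive_model_id="gpt-5", user_model_id="gpt-5",
    trace_file="trace.json", user_agent_decision="accept",
    agent_proposal="Book the cab?", meta_task_description="",
    llm_input=llm_input, final_decision=True,
)

resp = sample.to_api_response(progress_completed=3, progress_total=10)
print(resp.messages[0].content)  # "Book the cab?" -- [TASK] framing stripped
print(resp.scenario_context)     # None -- empty meta_task_description maps to None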

Data Models and Parsing

Data models for the annotation module.

MessageType

Bases: StrEnum

Classification of messages for UI rendering.

Source code in pare/annotation/models.py
class MessageType(StrEnum):
    """Classification of messages for UI rendering."""

    USER_ACTION = "user_action"
    TOOL_OBSERVATION = "tool_observation"
    PROPOSAL = "proposal"
    ENVIRONMENT_NOTIFICATION = "environment_notification"

SampleResponse

Bases: BaseModel

API response payload for a single annotation sample.

Source code in pare/annotation/models.py
class SampleResponse(BaseModel):
    """API response payload for a single annotation sample."""

    sample_id: str
    scenario_context: str | None
    messages: list[UIMessage]
    progress_completed: int
    progress_total: int
    tutorial: bool = False

UIMessage

Bases: BaseModel

A single renderable message for the annotation UI.

Source code in pare/annotation/models.py
class UIMessage(BaseModel):
    """A single renderable message for the annotation UI."""

    msg_type: MessageType
    content: str
    timestamp: float | None = None
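
These models nest directly. A short construction sketch (field values are hypothetical; serialization assumes pydantic v2):

from pare.annotation.models import MessageType, SampleResponse, UIMessage

msg = UIMessage(msg_type=MessageType.PROPOSAL, content="Book the cab?", timestamp=1700000000.0)
resp = SampleResponse(
    sample_id="s-001",
    scenario_context=None,
    messages=[msg],
    progress_completed=3,
    progress_total=10,
)
print(resp.model_dump_json(indent=2))  # pydantic v2 serialization (assumed)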

Proposal-centric trace parser for extracting decision points.

Single forward scan algorithm:

1. Identify agents by system prompt content.
2. Truncate the trace at the first execute-agent tool call.
3. Pair proposals with accept/reject decisions.
4. Classify: direct accept/reject vs gather_context.
5. Extract and annotate llm_input at the proposal point.

extract_decision_points(trace_path, proactive_model_id, user_model_id='unknown')

Extract all decision points from a trace file.

Parameters:

    trace_path (Path, required): Path to the trace JSON file.
    proactive_model_id (str, required): The proactive model identifier.
    user_model_id (str, default 'unknown'): The user model identifier.

Returns:

    list[DecisionPoint]: List of DecisionPoint objects, one per valid proposal-decision pair.

Source code in pare/trajectory/trace_parser.py
def extract_decision_points(
    trace_path: Path,
    proactive_model_id: str,
    user_model_id: str = "unknown",
) -> list[DecisionPoint]:
    """Extract all decision points from a trace file.

    Args:
        trace_path: Path to the trace JSON file.
        proactive_model_id: The proactive model identifier.
        user_model_id: The user model identifier.

    Returns:
        List of DecisionPoint objects, one per valid proposal-decision pair.
    """
    with open(trace_path) as f:
        data = json.load(f)

    logs = data.get("world_logs", [])
    # Parse string-encoded log entries, skip malformed ones
    parsed_logs = []
    for log_entry in logs:
        if isinstance(log_entry, str):
            try:
                parsed_logs.append(json.loads(log_entry))
            except json.JSONDecodeError:
                logger.warning(f"Malformed log entry in {trace_path}, skipping: {log_entry[:100]}")
                continue
        else:
            parsed_logs.append(log_entry)
    logs = parsed_logs

    # Check for RateLimitError
    raw_text = json.dumps(data)
    if "RateLimitError" in raw_text:
        logger.warning(f"Skipping trace with RateLimitError: {trace_path}")
        return []

    # Extract metadata
    metadata = data.get("metadata", {})
    definition = metadata.get("definition", {})
    scenario_id = definition.get("scenario_id", trace_path.stem.rsplit("_run_", 1)[0])
    run_number = definition.get("run_number", 1)

    # Extract meta_task_description from system prompt
    meta_task_description = _extract_meta_task_description(logs)

    # Identify agents
    agents = _identify_agents(logs)
    user_id = agents.get("user")
    observe_id = agents.get("observe")
    execute_id = agents.get("execute")

    if not user_id or not observe_id:
        logger.warning(f"Could not identify user or observe agent in {trace_path}")
        return []

    # Find execute agent cutoff
    execute_cutoff = _find_execute_cutoff(logs, execute_id)

    # Find proposal-decision pairs and build DecisionPoints
    return _extract_pairs(
        logs=logs,
        user_id=user_id,
        observe_id=observe_id,
        execute_cutoff=execute_cutoff,
        scenario_id=scenario_id,
        run_number=run_number,
        proactive_model_id=proactive_model_id,
        user_model_id=user_model_id,
        trace_path=trace_path,
        meta_task_description=meta_task_description,
    )
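
A call sketch, assuming the function is importable from pare.trajectory.trace_parser (the module shown above); the trace path and model IDs are hypothetical:

from pathlib import Path

from pare.trajectory.trace_parser import extract_decision_points

dps = extract_decision_points(
    trace_path=Path("traces/obs_gpt-5_exec_gpt-5_enmi_0_es_42_tfp_0.0/scenario_1_run_1.json"),
    proactive_model_id="gpt-5",
    user_model_id="gpt-5",
)
# Returns [] for traces containing RateLimitError or with unidentifiable agents.
print(f"Extracted {len(dps)} decision points")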

Observation formatter for converting raw tool observations to human-readable text.

Used by Sample.to_api_response() to format tool observations and notifications for human annotators. A redesign using proper typed object parsing is tracked as a separate feature (see observation-formatter-redesign feature doc).

This module provides formatters for all return types used in PARE apps.

ObservationFormatter

Formats raw observations into human-readable displays.

Source code in pare/annotation/observation_formatter.py
class ObservationFormatter:
    """Formats raw observations into human-readable displays."""

    @staticmethod
    def format(tool_name: str, raw_observation: Any, tool_args: dict[str, Any] | None = None) -> str:  # noqa: ANN401, C901
        """Format an observation based on the tool that produced it.

        Args:
            tool_name: The name of the tool (e.g., "Emails__list_emails").
            raw_observation: The raw observation data (usually a string repr of dataclass).
            tool_args: Optional dictionary of tool arguments for context.

        Returns:
            A human-readable formatted string.
        """
        obs_str = str(raw_observation)
        tool_args = tool_args or {}

        # Handle special simple cases
        if isinstance(raw_observation, str):
            if "Opened" in raw_observation and "App" in raw_observation:
                return raw_observation
            if "Switched to home screen" in raw_observation:
                return raw_observation
            if _is_uuid(raw_observation):
                return "Action completed."

        # Route based on tool name
        app_name = tool_name.split("__")[0] if "__" in tool_name else ""
        action_name = tool_name.split("__")[1] if "__" in tool_name else tool_name

        # Email app
        if app_name == "Emails":
            return _format_email_observation(action_name, obs_str, tool_args)

        # Contacts app
        if app_name == "Contacts":
            return _format_contacts_observation(action_name, obs_str, tool_args)

        # Calendar app
        if app_name == "Calendar":
            return _format_calendar_observation(action_name, obs_str, tool_args)

        # Messages app
        if app_name == "Messages":
            return _format_messages_observation(action_name, obs_str, tool_args)

        # Notes app
        if app_name == "Notes":
            return _format_notes_observation(action_name, obs_str, tool_args)

        # Reminder app
        if app_name == "Reminders":
            return _format_reminder_observation(action_name, obs_str, tool_args)

        # Cab app
        if app_name == "Cab":
            return _format_cab_observation(action_name, obs_str, tool_args)

        # Apartment app
        if app_name == "Apartment":
            return _format_apartment_observation(action_name, obs_str, tool_args)

        # Shopping app
        if app_name == "Shopping":
            return _format_shopping_observation(action_name, obs_str, tool_args)

        # System app
        if app_name == "System":
            return obs_str

        # PAREAgentUserInterface
        if app_name == "PAREAgentUserInterface":
            if _is_uuid(obs_str.strip()):
                return "Response recorded."
            return obs_str

        # Default: truncate if too long
        if len(obs_str) > 500:
            return obs_str[:500] + "..."
        return obs_str

format(tool_name, raw_observation, tool_args=None) staticmethod

Format an observation based on the tool that produced it.

Parameters:

    tool_name (str, required): The name of the tool (e.g., "Emails__list_emails").
    raw_observation (Any, required): The raw observation data (usually a string repr of dataclass).
    tool_args (dict[str, Any] | None, default None): Optional dictionary of tool arguments for context.

Returns:

    str: A human-readable formatted string.

Source code in pare/annotation/observation_formatter.py
@staticmethod
def format(tool_name: str, raw_observation: Any, tool_args: dict[str, Any] | None = None) -> str:  # noqa: ANN401, C901
    """Format an observation based on the tool that produced it.

    Args:
        tool_name: The name of the tool (e.g., "Emails__list_emails").
        raw_observation: The raw observation data (usually a string repr of dataclass).
        tool_args: Optional dictionary of tool arguments for context.

    Returns:
        A human-readable formatted string.
    """
    obs_str = str(raw_observation)
    tool_args = tool_args or {}

    # Handle special simple cases
    if isinstance(raw_observation, str):
        if "Opened" in raw_observation and "App" in raw_observation:
            return raw_observation
        if "Switched to home screen" in raw_observation:
            return raw_observation
        if _is_uuid(raw_observation):
            return "Action completed."

    # Route based on tool name
    app_name = tool_name.split("__")[0] if "__" in tool_name else ""
    action_name = tool_name.split("__")[1] if "__" in tool_name else tool_name

    # Email app
    if app_name == "Emails":
        return _format_email_observation(action_name, obs_str, tool_args)

    # Contacts app
    if app_name == "Contacts":
        return _format_contacts_observation(action_name, obs_str, tool_args)

    # Calendar app
    if app_name == "Calendar":
        return _format_calendar_observation(action_name, obs_str, tool_args)

    # Messages app
    if app_name == "Messages":
        return _format_messages_observation(action_name, obs_str, tool_args)

    # Notes app
    if app_name == "Notes":
        return _format_notes_observation(action_name, obs_str, tool_args)

    # Reminder app
    if app_name == "Reminders":
        return _format_reminder_observation(action_name, obs_str, tool_args)

    # Cab app
    if app_name == "Cab":
        return _format_cab_observation(action_name, obs_str, tool_args)

    # Apartment app
    if app_name == "Apartment":
        return _format_apartment_observation(action_name, obs_str, tool_args)

    # Shopping app
    if app_name == "Shopping":
        return _format_shopping_observation(action_name, obs_str, tool_args)

    # System app
    if app_name == "System":
        return obs_str

    # PAREAgentUserInterface
    if app_name == "PAREAgentUserInterface":
        if _is_uuid(obs_str.strip()):
            return "Response recorded."
        return obs_str

    # Default: truncate if too long
    if len(obs_str) > 500:
        return obs_str[:500] + "..."
    return obs_str
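
A short sketch of the routing behavior; the observation strings are hypothetical, and the example assumes the private _is_uuid helper matches canonical UUID strings:

from pare.annotation.observation_formatter import ObservationFormatter

# Unknown app: falls through to the default branch, first 500 chars plus "...".
print(len(ObservationFormatter.format("Weather__get_forecast", "x" * 600)))  # 503

# Bare UUID return values become "Action completed." before any app routing.
print(ObservationFormatter.format(
    "Emails__send_email", "3f2a9c1e-8b4d-4e6f-9a0b-1c2d3e4f5a6b"
))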

format_notification(raw_notification, id_to_name_map=None)

Format a notification to be human-readable.

Converts raw notifications like: '[2025-11-18 09:00:05] New message from 22c41f3ff12fe5f2a0a02c1da9d15b57 in conversation xyz: Hello!'

To: '[2025-11-18 09:00:05] New message: Hello!'

Parameters:

    raw_notification (str, required): The raw notification string.
    id_to_name_map (dict[str, str] | None, default None): Optional mapping from IDs to human-readable names.

Returns:

    str: A cleaned notification string.

Source code in pare/annotation/observation_formatter.py
def format_notification(raw_notification: str, id_to_name_map: dict[str, str] | None = None) -> str:
    """Format a notification to be human-readable.

    Converts raw notifications like:
    '[2025-11-18 09:00:05] New message from 22c41f3ff12fe5f2a0a02c1da9d15b57 in conversation xyz: Hello!'

    To:
    '[2025-11-18 09:00:05] New message: Hello!'

    Args:
        raw_notification: The raw notification string.
        id_to_name_map: Optional mapping from IDs to human-readable names.

    Returns:
        A cleaned notification string.
    """
    id_to_name_map = id_to_name_map or {}
    lines = []

    for line in raw_notification.strip().split("\n"):
        line = line.strip()
        if not line:
            continue

        # Handle message notifications - match any ID format (hex, conv_xxx, etc.)
        # Pattern: [timestamp] New message from <any_id> in [conversation] <any_id>: <content>
        msg_match = re.match(
            r"(\[[\d\-: ]+\])\s*New message from (\S+) in (?:conversation )?(\S+):\s*(.+)",
            line,
            re.IGNORECASE | re.DOTALL,
        )
        if msg_match:
            timestamp, sender_id, _conv_id, content = msg_match.groups()
            # Strip trailing colon from conv_id if present
            _conv_id = _conv_id.rstrip(":")
            sender_name = id_to_name_map.get(sender_id, None)
            if sender_name:
                lines.append(f"{timestamp} New message from {sender_name}: {content}")
            else:
                # Just show the content without the confusing IDs
                lines.append(f"{timestamp} New message: {content}")
            continue

        # Handle email notifications: "New email from <email>: <subject>"
        email_match = re.match(
            r"(\[[\d\-: ]+\])\s*New email from ([^:]+):\s*(.+)",
            line,
            re.IGNORECASE,
        )
        if email_match:
            timestamp, sender, subject = email_match.groups()
            lines.append(f"{timestamp} New email from {sender}: {subject}")
            continue

        # Handle calendar notifications
        cal_match = re.match(
            r"(\[[\d\-: ]+\])\s*(Upcoming event|Event reminder|Calendar):\s*(.+)",
            line,
            re.IGNORECASE,
        )
        if cal_match:
            timestamp, notif_type, content = cal_match.groups()
            lines.append(f"{timestamp} {notif_type}: {content}")
            continue

        # Handle calendar event with ID: "Calendar event <id> deleted/updated/created by <name>"
        # ID can be hex (a-f0-9) or readable format (word_word_123)
        cal_event_match = re.match(
            r"(\[[\d\-: ]+\])\s*Calendar event \S+ (deleted|updated|created|cancelled)(.*)$",
            line,
            re.IGNORECASE,
        )
        if cal_event_match:
            timestamp, action, rest = cal_event_match.groups()
            lines.append(f"{timestamp} Calendar event {action}{rest}")
            continue

        # Handle reminder notifications
        reminder_match = re.match(
            r"(\[[\d\-: ]+\])\s*(Reminder):\s*(.+)",
            line,
            re.IGNORECASE,
        )
        if reminder_match:
            timestamp, notif_type, content = reminder_match.groups()
            lines.append(f"{timestamp} {notif_type}: {content}")
            continue

        # Handle other notifications - just pass through
        lines.append(line)

    return "\n".join(lines)
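
Applying it to the docstring's example input (the sender ID and the mapping are hypothetical):

from pare.annotation.observation_formatter import format_notification

raw = "[2025-11-18 09:00:05] New message from 22c41f3ff12fe5f2a0a02c1da9d15b57 in conversation xyz: Hello!"
print(format_notification(raw))
# [2025-11-18 09:00:05] New message: Hello!

names = {"22c41f3ff12fe5f2a0a02c1da9d15b57": "Alice"}
print(format_notification(raw, id_to_name_map=names))
# [2025-11-18 09:00:05] New message from Alice: Hello!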

Sampling and Server

Balanced sampler for annotation dataset creation.

balanced_sample_ternary(candidates, sample_size, seed=None)

Sample decision points balanced by both decision type and proactive model.

Algorithm: For each proactive model, create three pools (accept, reject, gather_context). Cycle through models, and within each model cycle through decision type pools, drawing one sample per pool per cycle. If a pool is exhausted, skip it. Stop when target count reached or all pools across all models are empty.

Parameters:

    candidates (list[TernaryDecisionPoint], required): List of candidate decision points.
    sample_size (int, required): Number of samples to select.
    seed (int | None, default None): Random seed for reproducibility.

Returns:

    list[TernaryDecisionPoint]: List of selected TernaryDecisionPoint objects.

Source code in pare/annotation/sampler.py
def balanced_sample_ternary(
    candidates: list[TernaryDecisionPoint],
    sample_size: int,
    seed: int | None = None,
) -> list[TernaryDecisionPoint]:
    """Sample decision points balanced by both decision type and proactive model.

    Algorithm: For each proactive model, create three pools (accept, reject, gather_context).
    Cycle through models, and within each model cycle through decision type pools, drawing
    one sample per pool per cycle. If a pool is exhausted, skip it. Stop when target count
    reached or all pools across all models are empty.

    Args:
        candidates: List of candidate decision points.
        sample_size: Number of samples to select.
        seed: Random seed for reproducibility.

    Returns:
        List of selected TernaryDecisionPoint objects.
    """
    if seed is not None:
        random.seed(seed)

    decision_types = ["accept", "reject", "gather_context"]

    # Group by model, then by decision type
    model_ids = sorted({c.proactive_model_id for c in candidates})
    model_pools: dict[str, dict[str, list[TernaryDecisionPoint]]] = {}
    for model_id in model_ids:
        model_pools[model_id] = {}
        for dt in decision_types:
            pool = [c for c in candidates if c.proactive_model_id == model_id and c.user_agent_decision == dt]
            random.shuffle(pool)
            model_pools[model_id][dt] = pool

    selected: list[TernaryDecisionPoint] = []

    # Cycle: for each model, draw one from each decision type pool
    while len(selected) < sample_size:
        any_picked = _draw_one_cycle(model_ids, decision_types, model_pools, selected, sample_size)
        if not any_picked:
            logger.warning(f"Ran out of candidates after selecting {len(selected)} samples")
            break

    # Log balance statistics
    accepts_count = len([s for s in selected if s.user_agent_decision == "accept"])
    rejects_count = len([s for s in selected if s.user_agent_decision == "reject"])
    gather_context_count = len([s for s in selected if s.user_agent_decision == "gather_context"])
    model_counts = {m: len([s for s in selected if s.proactive_model_id == m]) for m in model_ids}
    logger.info(
        f"Selected {len(selected)} samples: {accepts_count} accepts, {rejects_count} rejects, "
        f"{gather_context_count} gather_context. Per model: {model_counts}"
    )

    return selected
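
A call sketch pairing extraction with balanced sampling; the traces directory layout is hypothetical:

from pathlib import Path

from pare.annotation.sampler import (
    balanced_sample_ternary,
    extract_all_decision_points_ternary,
)

candidates = extract_all_decision_points_ternary(
    traces_dir=Path("traces"),  # expects model subdirectories like obs_..._enmi_0_...
    user_model_id="gpt-5",
)
# 60 samples, balanced across proactive models and the three decision types.
selected = balanced_sample_ternary(candidates, sample_size=60, seed=42)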

extract_all_decision_points_ternary(traces_dir, user_model_id, target_models=None)

Extract ternary decision points from all trace files in a directory.

Walks the traces directory, identifies model subdirectories, and extracts decision points from each no-noise trace file using the ternary parser.

Parameters:

    traces_dir (Path, required): Root directory containing model subdirectories with traces.
    user_model_id (str, required): The user model that generated these traces.
    target_models (list[str] | None, default None): If provided, only extract from these proactive model IDs.

Returns:

    list[TernaryDecisionPoint]: List of TernaryDecisionPoint objects from all matching traces.

Source code in pare/annotation/sampler.py
def extract_all_decision_points_ternary(
    traces_dir: Path,
    user_model_id: str,
    target_models: list[str] | None = None,
) -> list[TernaryDecisionPoint]:
    """Extract ternary decision points from all trace files in a directory.

    Walks the traces directory, identifies model subdirectories, and extracts
    decision points from each no-noise trace file using the ternary parser.

    Args:
        traces_dir: Root directory containing model subdirectories with traces.
        user_model_id: The user model that generated these traces.
        target_models: If provided, only extract from these proactive model IDs.

    Returns:
        List of TernaryDecisionPoint objects from all matching traces.
    """
    all_dps: list[TernaryDecisionPoint] = []

    for model_dir in sorted(traces_dir.iterdir()):
        if not model_dir.is_dir() or not is_no_noise_trace(model_dir.name):
            continue

        model_id = extract_model_id_from_dir(model_dir.name)
        if not model_id:
            continue

        if target_models and model_id not in target_models:
            continue

        trace_files = sorted(model_dir.glob("*.json"))

        count_before = len(all_dps)
        for trace_file in trace_files:
            dps = extract_ternary_decision_points(trace_file, model_id, user_model_id)
            all_dps.extend(dps)

        # Log the per-model count rather than the running total
        logger.info(
            f"Extracted {len(all_dps) - count_before} decision points from {model_id} ({len(trace_files)} traces)"
        )

    logger.info(f"Total: {len(all_dps)} ternary decision points from {traces_dir}")
    return all_dps

extract_model_id_from_dir(dir_name)

Extract the proactive model ID from a trace subdirectory name.

Example: obs_gpt-5_exec_gpt-5_enmi_0_es_42_tfp_0.0 -> gpt-5

Parameters:

    dir_name (str, required): The subdirectory name.

Returns:

    str: The extracted proactive model ID.

Source code in pare/annotation/sampler.py
def extract_model_id_from_dir(dir_name: str) -> str:
    """Extract the proactive model ID from a trace subdirectory name.

    Example: obs_gpt-5_exec_gpt-5_enmi_0_es_42_tfp_0.0 -> gpt-5

    Args:
        dir_name: The subdirectory name.

    Returns:
        The extracted proactive model ID.
    """
    match = re.match(r"obs_([^_]+(?:_[^_]+)?(?:-[^_]+)?)_exec_", dir_name)
    if match:
        return match.group(1)
    return dir_name

is_no_noise_trace(dir_name)

Check if a trace directory is a no-noise trace.

No-noise traces have enmi_0 (environment noise = 0).

Parameters:

    dir_name (str, required): The directory name.

Returns:

    bool: True if this is a no-noise trace.

Source code in pare/annotation/sampler.py
def is_no_noise_trace(dir_name: str) -> bool:
    """Check if a trace directory is a no-noise trace.

    No-noise traces have enmi_0 (environment noise = 0).

    Args:
        dir_name: The directory name.

    Returns:
        True if this is a no-noise trace.
    """
    return "enmi_0" in dir_name
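
Doctest-style checks of the two directory-name helpers, using the example name from above:

from pare.annotation.sampler import extract_model_id_from_dir, is_no_noise_trace

assert extract_model_id_from_dir("obs_gpt-5_exec_gpt-5_enmi_0_es_42_tfp_0.0") == "gpt-5"
assert is_no_noise_trace("obs_gpt-5_exec_gpt-5_enmi_0_es_42_tfp_0.0") is True
assert is_no_noise_trace("obs_gpt-5_exec_gpt-5_enmi_3_es_42_tfp_0.0") is False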

sample_new_datapoints_ternary(traces_dir, samples_file, user_model_id, sample_size, seed=None, target_models=None)

Extract, sample, and save ternary decision points.

End-to-end function: extracts all ternary decision points from traces, deduplicates against existing samples, applies three-way balanced sampling, and saves the result to parquet.

Parameters:

    traces_dir (Path, required): Root directory containing model subdirectories with traces.
    samples_file (Path, required): Path to the output parquet file (for dedup and save).
    user_model_id (str, required): The user model that generated these traces.
    sample_size (int, required): Number of new samples to select.
    seed (int | None, default None): Random seed for reproducibility.
    target_models (list[str] | None, default None): If provided, only extract from these proactive model IDs.

Returns:

    list[TernaryDecisionPoint]: List of newly selected TernaryDecisionPoint objects.

Source code in pare/annotation/sampler.py
def sample_new_datapoints_ternary(
    traces_dir: Path,
    samples_file: Path,
    user_model_id: str,
    sample_size: int,
    seed: int | None = None,
    target_models: list[str] | None = None,
) -> list[TernaryDecisionPoint]:
    """Extract, sample, and save ternary decision points.

    End-to-end function: extracts all ternary decision points from traces,
    deduplicates against existing samples, applies three-way balanced sampling,
    and saves the result to parquet.

    Args:
        traces_dir: Root directory containing model subdirectories with traces.
        samples_file: Path to the output parquet file (for dedup and save).
        user_model_id: The user model that generated these traces.
        sample_size: Number of new samples to select.
        seed: Random seed for reproducibility.
        target_models: If provided, only extract from these proactive model IDs.

    Returns:
        List of newly selected TernaryDecisionPoint objects.
    """
    # Extract all decision points
    all_dps = extract_all_decision_points_ternary(traces_dir, user_model_id, target_models)

    if not all_dps:
        logger.warning("No decision points found in traces")
        return []

    # Deduplicate against existing samples
    existing_ids: set[str] = set()
    if samples_file.exists():
        existing_df = pl.read_parquet(samples_file)
        existing_ids = set(existing_df["sample_id"].to_list())
        logger.info(f"Found {len(existing_ids)} existing samples for deduplication")

    new_candidates = [dp for dp in all_dps if dp.sample_id not in existing_ids]
    logger.info(f"Found {len(new_candidates)} new candidates (filtered {len(all_dps) - len(new_candidates)} existing)")

    if not new_candidates:
        logger.warning("No new candidates available for sampling")
        return []

    # Three-way balanced sampling
    selected = balanced_sample_ternary(new_candidates, sample_size, seed)

    # Save to parquet
    save_samples_ternary(selected, samples_file)

    return selected
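
A minimal end-to-end invocation sketch; the paths, model IDs, and sample size below are illustrative placeholders, not project defaults:

from pathlib import Path

from pare.annotation.sampler import sample_new_datapoints_ternary

selected = sample_new_datapoints_ternary(
    traces_dir=Path("traces/"),                      # hypothetical trace root
    samples_file=Path("annotation/samples.parquet"),
    user_model_id="gpt-5",
    sample_size=100,
    seed=42,                                         # fixed seed for reproducible sampling
    target_models=["gpt-5"],                         # optionally restrict to one proactive model
)
print(f"Selected {len(selected)} new decision points")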

save_samples_ternary(samples, output_file)

Save ternary samples to parquet, checking schema compatibility.

If the output file already exists, validates that it uses the ternary schema (user_agent_decision as String, not Boolean). Raises SystemExit if an incompatible binary-schema parquet is found.

Parameters:

    samples (list[TernaryDecisionPoint], required): List of TernaryDecisionPoint objects to save.
    output_file (Path, required): Path to the output parquet file.

Returns:

    Path: Path to the samples file.

Source code in pare/annotation/sampler.py
def save_samples_ternary(samples: list[TernaryDecisionPoint], output_file: Path) -> Path:
    """Save ternary samples to parquet, checking schema compatibility.

    If the output file already exists, validates that it uses the ternary schema
    (user_agent_decision as String, not Boolean). Raises SystemExit if an
    incompatible binary-schema parquet is found.

    Args:
        samples: List of TernaryDecisionPoint objects to save.
        output_file: Path to the output parquet file.

    Returns:
        Path to the samples file.
    """
    import sys

    if not samples:
        logger.warning("No samples to save")
        return output_file

    output_file.parent.mkdir(parents=True, exist_ok=True)
    new_df = pl.DataFrame([s.to_sample_dict() for s in samples])

    # Add tutorial columns with defaults (real samples are never tutorials)
    new_df = new_df.with_columns(
        pl.lit(False).alias("tutorial"),
        pl.lit(None).cast(pl.Utf8).alias("correct_decision"),
        pl.lit(None).cast(pl.Utf8).alias("explanation"),
    )

    if output_file.exists():
        existing_df = pl.read_parquet(output_file)

        # Check schema compatibility
        if "user_agent_decision" in existing_df.columns and existing_df["user_agent_decision"].dtype == pl.Boolean:
            logger.error(
                f"Existing parquet {output_file} uses binary schema (user_agent_decision: bool). "
                "Delete it and re-sample with the ternary pipeline."
            )
            sys.exit(1)

        # Align column order and use diagonal concat to handle type mismatches
        # (e.g., Null vs String when all values in a column are null)
        new_df = new_df.select(existing_df.columns)
        combined_df = pl.concat([existing_df, new_df], how="diagonal_relaxed")
        combined_df.write_parquet(output_file)
        logger.info(f"Appended {len(samples)} samples to {output_file} (total: {len(combined_df)})")
    else:
        new_df.write_parquet(output_file)
        logger.info(f"Created {output_file} with {len(samples)} samples")

    return output_file
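
Why how="diagonal_relaxed": when every value of a column in one frame is null, polars infers a Null dtype that a strict vertical concat can reject. A toy reproduction of the mismatch this handles (behavior may vary slightly across polars versions):

import polars as pl

existing = pl.DataFrame({"sample_id": ["a"], "explanation": ["seed sample"]})
# An all-null column is inferred as dtype Null rather than String.
new = pl.DataFrame({"sample_id": ["b"], "explanation": [None]})

combined = pl.concat([existing, new.select(existing.columns)], how="diagonal_relaxed")
print(combined)  # two rows; explanation is String with a null in the second row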

FastAPI server for the annotation interface.

AnnotationRequest

Bases: BaseModel

Request body for submitting an annotation.

Source code in pare/annotation/server.py
class AnnotationRequest(BaseModel):
    """Request body for submitting an annotation."""

    sample_id: str
    decision: TernaryDecision
    gather_context_rationale: str | None = None

AnnotationServer

Server for managing annotation state and serving samples.

Source code in pare/annotation/server.py
class AnnotationServer:
    """Server for managing annotation state and serving samples."""

    def __init__(self, samples_file: Path, annotations_file: Path, annotators_per_sample: int = 2) -> None:
        """Initialize the annotation server.

        Args:
            samples_file: Path to the samples parquet file.
            annotations_file: Path to the annotations CSV file (created if not exists).
            annotators_per_sample: Number of annotations required per sample.
        """
        self.samples_file = samples_file
        self.annotations_file = annotations_file
        self.annotators_per_sample = annotators_per_sample

        # Load samples
        if not samples_file.exists():
            raise FileNotFoundError(f"Samples file not found: {samples_file}. Run 'pare annotation sample' first.")

        self.samples_df = pl.read_parquet(samples_file)
        logger.info(f"Loaded {len(self.samples_df)} samples from {samples_file}")

        # Build sample lookup
        self._samples: dict[str, Sample] = {}
        for row in self.samples_df.iter_rows(named=True):
            sample = Sample(**row)
            self._samples[sample.sample_id] = sample

        # Split tutorial and real samples
        self._tutorial_samples: dict[str, Sample] = {sid: s for sid, s in self._samples.items() if s.tutorial}
        self._real_samples: dict[str, Sample] = {sid: s for sid, s in self._samples.items() if not s.tutorial}
        if self._tutorial_samples:
            logger.info(f"Tutorial: {len(self._tutorial_samples)}, Real: {len(self._real_samples)}")
        else:
            logger.info(f"No tutorial samples found. All {len(self._real_samples)} samples are for annotation.")

        # Initialize annotations file
        self.annotations_file.parent.mkdir(parents=True, exist_ok=True)
        if not self.annotations_file.exists():
            with open(self.annotations_file, "w") as f:
                f.write(Annotation.csv_header())
            logger.info(f"Created annotations file: {self.annotations_file}")

        # In-memory annotation tracking (real samples only)
        self._annotation_counts: dict[str, int] = {}  # sample_id -> count
        self._user_annotations: dict[str, set[str]] = {}  # user_id -> set of sample_ids
        self._lock = threading.Lock()

        # Load existing annotations
        self._load_annotation_state()

        # Tutorial completion tracking
        # Maps annotator_id -> {sample_id: decision} for tutorial samples
        self._tutorial_annotations: dict[str, dict[str, str]] = {}
        self._tutorial_annotations_file = self.annotations_file.parent / f"{self.annotations_file.stem}_tutorial.csv"
        if not self._tutorial_annotations_file.exists():
            with open(self._tutorial_annotations_file, "w") as f:
                f.write(Annotation.csv_header())
        self._load_tutorial_annotation_state()

    def _load_annotation_state(self) -> None:
        """Load existing annotations into memory."""
        if not self.annotations_file.exists():
            return

        try:
            annotations_df = pl.read_csv(self.annotations_file)
            if len(annotations_df) == 0:
                return

            # Count annotations per sample
            for sample_id in annotations_df["sample_id"].unique().to_list():
                count = len(annotations_df.filter(pl.col("sample_id") == sample_id))
                self._annotation_counts[sample_id] = count

            # Track which samples each user has annotated
            for row in annotations_df.iter_rows(named=True):
                user_id = row["annotator_id"]
                sample_id = row["sample_id"]
                if user_id not in self._user_annotations:
                    self._user_annotations[user_id] = set()
                self._user_annotations[user_id].add(sample_id)

            logger.info(
                f"Loaded {len(annotations_df)} existing annotations from {len(self._user_annotations)} annotators"
            )

        except Exception as e:
            logger.warning(f"Failed to load existing annotations: {e}")

    def _load_tutorial_annotation_state(self) -> None:
        """Load existing tutorial annotations into memory."""
        if not self._tutorial_annotations_file.exists():
            return
        try:
            df = pl.read_csv(self._tutorial_annotations_file)
            if len(df) == 0:
                return
            for row in df.iter_rows(named=True):
                annotator_id = row["annotator_id"]
                sample_id = row["sample_id"]
                decision = row["human_decision"]
                if annotator_id not in self._tutorial_annotations:
                    self._tutorial_annotations[annotator_id] = {}
                self._tutorial_annotations[annotator_id][sample_id] = decision
            logger.info(f"Loaded tutorial annotations for {len(self._tutorial_annotations)} annotators")
        except Exception as e:
            logger.warning(f"Failed to load tutorial annotations: {e}")

    def get_sample(self, sample_id: str) -> Sample | None:
        """Get a sample by ID (tutorial or real)."""
        return self._samples.get(sample_id)

    def get_next_sample(self, annotator_id: str) -> Sample | None:
        """Get the next available real (non-tutorial) sample for an annotator.

        Args:
            annotator_id: The annotator's anonymous ID.

        Returns:
            The next real sample to annotate, or None if all done.
        """
        user_done = self._user_annotations.get(annotator_id, set())

        for sample_id, sample in self._real_samples.items():
            if sample_id in user_done:
                continue
            if self._annotation_counts.get(sample_id, 0) >= self.annotators_per_sample:
                continue
            return sample

        return None

    def record_annotation(
        self,
        sample_id: str,
        annotator_id: str,
        human_decision: TernaryDecision,
        gather_context_rationale: str | None = None,
    ) -> bool:
        """Record a real (non-tutorial) annotation.

        Args:
            sample_id: The sample being annotated.
            annotator_id: The annotator's anonymous ID.
            human_decision: The human's accept/reject/gather_context decision.
            gather_context_rationale: Free-text rationale when decision is gather_context.

        Returns:
            True if recorded successfully.

        Raises:
            ValueError: If sample not found or already annotated by this user.
        """
        sample = self.get_sample(sample_id)
        if not sample:
            raise ValueError(f"Sample not found: {sample_id}")

        with self._lock:
            if annotator_id in self._user_annotations and sample_id in self._user_annotations[annotator_id]:
                raise ValueError(f"User {annotator_id} already annotated sample {sample_id}")

            annotation = Annotation.create(
                sample_id=sample_id,
                annotator_id=annotator_id,
                human_decision=human_decision,
                gather_context_rationale=gather_context_rationale,
            )

            with open(self.annotations_file, "a") as f:
                f.write(annotation.to_csv_row())

            self._annotation_counts[sample_id] = self._annotation_counts.get(sample_id, 0) + 1

            if annotator_id not in self._user_annotations:
                self._user_annotations[annotator_id] = set()
            self._user_annotations[annotator_id].add(sample_id)

            logger.info(
                f"Recorded annotation: sample={sample_id[:20]}..., user={annotator_id[:8]}..., decision={human_decision}"
            )

        return True

    def get_progress(self, annotator_id: str) -> dict[str, int]:
        """Get progress statistics for an annotator (real samples only).

        Args:
            annotator_id: The annotator's anonymous ID.

        Returns:
            Dictionary with completed and total counts.
        """
        user_done = self._user_annotations.get(annotator_id, set())

        total = 0
        completed = len(user_done)

        for sample_id in self._real_samples:
            if self._annotation_counts.get(sample_id, 0) < self.annotators_per_sample:
                total += 1

        total = max(total, completed)

        return {
            "completed": completed,
            "total": total,
        }

    def get_overall_stats(self) -> dict[str, Any]:
        """Get overall annotation statistics."""
        total_samples = len(self._real_samples)

        complete_count = sum(1 for count in self._annotation_counts.values() if count >= self.annotators_per_sample)
        in_progress_count = sum(
            1 for count in self._annotation_counts.values() if 0 < count < self.annotators_per_sample
        )
        not_started_count = total_samples - len(self._annotation_counts)

        total_annotations = sum(self._annotation_counts.values())
        unique_annotators = len(self._user_annotations)

        return {
            "total_samples": total_samples,
            "complete": complete_count,
            "in_progress": in_progress_count,
            "not_started": not_started_count,
            "total_annotations": total_annotations,
            "unique_annotators": unique_annotators,
            "annotators_per_sample": self.annotators_per_sample,
        }

    # --- Tutorial methods ---

    def is_tutorial_completed(self, annotator_id: str) -> bool:
        """Check if annotator has completed all tutorial samples.

        Args:
            annotator_id: The annotator's anonymous ID.

        Returns:
            True if all tutorial samples answered or no tutorials configured.
        """
        if not self._tutorial_samples:
            return True
        with self._lock:
            done = self._tutorial_annotations.get(annotator_id, {})
            return len(done) >= len(self._tutorial_samples)

    def get_next_tutorial_sample(self, annotator_id: str) -> Sample | None:
        """Get the next unanswered tutorial sample for an annotator.

        Args:
            annotator_id: The annotator's anonymous ID.

        Returns:
            Next tutorial Sample, or None if all done.
        """
        with self._lock:
            done = self._tutorial_annotations.get(annotator_id, {})
            for sample_id, sample in self._tutorial_samples.items():
                if sample_id not in done:
                    return sample
            return None

    def get_tutorial_summary(self, annotator_id: str) -> dict[str, Any]:
        """Get tutorial completion summary for an annotator.

        Uses in-memory tutorial annotation data (no file I/O).

        Args:
            annotator_id: The annotator's anonymous ID.

        Returns:
            Summary dict with correct, scored_total, total, answered.
        """
        with self._lock:
            decisions = self._tutorial_annotations.get(annotator_id, {})
            correct = 0
            scored_total = 0
            for sample_id, decision in decisions.items():
                sample = self._tutorial_samples.get(sample_id)
                if sample and sample.correct_decision is not None:
                    scored_total += 1
                    if decision == sample.correct_decision:
                        correct += 1
            return {
                "correct": correct,
                "scored_total": scored_total,
                "total": len(self._tutorial_samples),
                "answered": len(decisions),
            }

    def record_tutorial_annotation(
        self,
        sample_id: str,
        annotator_id: str,
        human_decision: TernaryDecision,
        gather_context_rationale: str | None = None,
    ) -> dict[str, Any]:
        """Record a tutorial annotation and return feedback.

        Writes to tutorial_annotations.csv (not annotations.csv).

        Args:
            sample_id: The tutorial sample being annotated.
            annotator_id: The annotator's anonymous ID.
            human_decision: The annotator's decision.
            gather_context_rationale: Free-text rationale when decision is gather_context.

        Returns:
            Feedback dict with correct, correct_decision, and explanation.

        Raises:
            ValueError: If sample not found in tutorial samples.
        """
        sample = self._tutorial_samples.get(sample_id)
        if not sample:
            raise ValueError(f"Tutorial sample not found: {sample_id}")

        annotation = Annotation.create(
            sample_id=sample_id,
            annotator_id=annotator_id,
            human_decision=human_decision,
            gather_context_rationale=gather_context_rationale,
        )

        with self._lock:
            with open(self._tutorial_annotations_file, "a") as f:
                f.write(annotation.to_csv_row())
            if annotator_id not in self._tutorial_annotations:
                self._tutorial_annotations[annotator_id] = {}
            self._tutorial_annotations[annotator_id][sample_id] = human_decision

        correct_decision = sample.correct_decision
        is_correct = human_decision == correct_decision if correct_decision is not None else None

        logger.info(
            f"Tutorial annotation: sample={sample_id[:20]}..., user={annotator_id[:8]}..., "
            f"decision={human_decision}, correct={is_correct}"
        )

        return {
            "correct": is_correct,
            "correct_decision": correct_decision,
            "explanation": sample.explanation or "",
        }
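
A minimal sketch of driving the server state directly, outside FastAPI. The paths and annotator ID are placeholders, and passing the literal "accept" assumes TernaryDecision admits that value:

from pathlib import Path

from pare.annotation.server import AnnotationServer

server = AnnotationServer(
    samples_file=Path("annotation/samples.parquet"),
    annotations_file=Path("annotation/annotations.csv"),
    annotators_per_sample=2,
)

sample = server.get_next_sample("annotator-123")
if sample is not None:
    # Recording appends a CSV row and updates the in-memory counts.
    server.record_annotation(sample.sample_id, "annotator-123", "accept")
print(server.get_progress("annotator-123"))  # e.g. {"completed": 1, "total": ...}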

__init__(samples_file, annotations_file, annotators_per_sample=2)

Initialize the annotation server.

Parameters:

    samples_file (Path, required): Path to the samples parquet file.
    annotations_file (Path, required): Path to the annotations CSV file (created if not exists).
    annotators_per_sample (int, default 2): Number of annotations required per sample.
Source code in pare/annotation/server.py
def __init__(self, samples_file: Path, annotations_file: Path, annotators_per_sample: int = 2) -> None:
    """Initialize the annotation server.

    Args:
        samples_file: Path to the samples parquet file.
        annotations_file: Path to the annotations CSV file (created if not exists).
        annotators_per_sample: Number of annotations required per sample.
    """
    self.samples_file = samples_file
    self.annotations_file = annotations_file
    self.annotators_per_sample = annotators_per_sample

    # Load samples
    if not samples_file.exists():
        raise FileNotFoundError(f"Samples file not found: {samples_file}. Run 'pare annotation sample' first.")

    self.samples_df = pl.read_parquet(samples_file)
    logger.info(f"Loaded {len(self.samples_df)} samples from {samples_file}")

    # Build sample lookup
    self._samples: dict[str, Sample] = {}
    for row in self.samples_df.iter_rows(named=True):
        sample = Sample(**row)
        self._samples[sample.sample_id] = sample

    # Split tutorial and real samples
    self._tutorial_samples: dict[str, Sample] = {sid: s for sid, s in self._samples.items() if s.tutorial}
    self._real_samples: dict[str, Sample] = {sid: s for sid, s in self._samples.items() if not s.tutorial}
    if self._tutorial_samples:
        logger.info(f"Tutorial: {len(self._tutorial_samples)}, Real: {len(self._real_samples)}")
    else:
        logger.info(f"No tutorial samples found. All {len(self._real_samples)} samples are for annotation.")

    # Initialize annotations file
    self.annotations_file.parent.mkdir(parents=True, exist_ok=True)
    if not self.annotations_file.exists():
        with open(self.annotations_file, "w") as f:
            f.write(Annotation.csv_header())
        logger.info(f"Created annotations file: {self.annotations_file}")

    # In-memory annotation tracking (real samples only)
    self._annotation_counts: dict[str, int] = {}  # sample_id -> count
    self._user_annotations: dict[str, set[str]] = {}  # user_id -> set of sample_ids
    self._lock = threading.Lock()

    # Load existing annotations
    self._load_annotation_state()

    # Tutorial completion tracking
    # Maps annotator_id -> {sample_id: decision} for tutorial samples
    self._tutorial_annotations: dict[str, dict[str, str]] = {}
    self._tutorial_annotations_file = self.annotations_file.parent / f"{self.annotations_file.stem}_tutorial.csv"
    if not self._tutorial_annotations_file.exists():
        with open(self._tutorial_annotations_file, "w") as f:
            f.write(Annotation.csv_header())
    self._load_tutorial_annotation_state()

get_next_sample(annotator_id)

Get the next available real (non-tutorial) sample for an annotator.

Parameters:

    annotator_id (str, required): The annotator's anonymous ID.

Returns:

    Sample | None: The next real sample to annotate, or None if all done.

Source code in pare/annotation/server.py
def get_next_sample(self, annotator_id: str) -> Sample | None:
    """Get the next available real (non-tutorial) sample for an annotator.

    Args:
        annotator_id: The annotator's anonymous ID.

    Returns:
        The next real sample to annotate, or None if all done.
    """
    user_done = self._user_annotations.get(annotator_id, set())

    for sample_id, sample in self._real_samples.items():
        if sample_id in user_done:
            continue
        if self._annotation_counts.get(sample_id, 0) >= self.annotators_per_sample:
            continue
        return sample

    return None

get_next_tutorial_sample(annotator_id)

Get the next unanswered tutorial sample for an annotator.

Parameters:

    annotator_id (str, required): The annotator's anonymous ID.

Returns:

    Sample | None: Next tutorial Sample, or None if all done.

Source code in pare/annotation/server.py
def get_next_tutorial_sample(self, annotator_id: str) -> Sample | None:
    """Get the next unanswered tutorial sample for an annotator.

    Args:
        annotator_id: The annotator's anonymous ID.

    Returns:
        Next tutorial Sample, or None if all done.
    """
    with self._lock:
        done = self._tutorial_annotations.get(annotator_id, {})
        for sample_id, sample in self._tutorial_samples.items():
            if sample_id not in done:
                return sample
        return None

get_overall_stats()

Get overall annotation statistics.

Source code in pare/annotation/server.py
def get_overall_stats(self) -> dict[str, Any]:
    """Get overall annotation statistics."""
    total_samples = len(self._real_samples)

    complete_count = sum(1 for count in self._annotation_counts.values() if count >= self.annotators_per_sample)
    in_progress_count = sum(
        1 for count in self._annotation_counts.values() if 0 < count < self.annotators_per_sample
    )
    not_started_count = total_samples - len(self._annotation_counts)

    total_annotations = sum(self._annotation_counts.values())
    unique_annotators = len(self._user_annotations)

    return {
        "total_samples": total_samples,
        "complete": complete_count,
        "in_progress": in_progress_count,
        "not_started": not_started_count,
        "total_annotations": total_annotations,
        "unique_annotators": unique_annotators,
        "annotators_per_sample": self.annotators_per_sample,
    }

get_progress(annotator_id)

Get progress statistics for an annotator (real samples only).

Parameters:

    annotator_id (str, required): The annotator's anonymous ID.

Returns:

    dict[str, int]: Dictionary with completed and total counts.

Source code in pare/annotation/server.py
def get_progress(self, annotator_id: str) -> dict[str, int]:
    """Get progress statistics for an annotator (real samples only).

    Args:
        annotator_id: The annotator's anonymous ID.

    Returns:
        Dictionary with completed and total counts.
    """
    user_done = self._user_annotations.get(annotator_id, set())

    total = 0
    completed = len(user_done)

    for sample_id in self._real_samples:
        if self._annotation_counts.get(sample_id, 0) < self.annotators_per_sample:
            total += 1

    total = max(total, completed)

    return {
        "completed": completed,
        "total": total,
    }

get_sample(sample_id)

Get a sample by ID (tutorial or real).

Source code in pare/annotation/server.py
def get_sample(self, sample_id: str) -> Sample | None:
    """Get a sample by ID (tutorial or real)."""
    return self._samples.get(sample_id)

get_tutorial_summary(annotator_id)

Get tutorial completion summary for an annotator.

Uses in-memory tutorial annotation data (no file I/O).

Parameters:

    annotator_id (str, required): The annotator's anonymous ID.

Returns:

    dict[str, Any]: Summary dict with correct, scored_total, total, answered.

Source code in pare/annotation/server.py
def get_tutorial_summary(self, annotator_id: str) -> dict[str, Any]:
    """Get tutorial completion summary for an annotator.

    Uses in-memory tutorial annotation data (no file I/O).

    Args:
        annotator_id: The annotator's anonymous ID.

    Returns:
        Summary dict with correct, scored_total, total, answered.
    """
    with self._lock:
        decisions = self._tutorial_annotations.get(annotator_id, {})
        correct = 0
        scored_total = 0
        for sample_id, decision in decisions.items():
            sample = self._tutorial_samples.get(sample_id)
            if sample and sample.correct_decision is not None:
                scored_total += 1
                if decision == sample.correct_decision:
                    correct += 1
        return {
            "correct": correct,
            "scored_total": scored_total,
            "total": len(self._tutorial_samples),
            "answered": len(decisions),
        }

is_tutorial_completed(annotator_id)

Check if annotator has completed all tutorial samples.

Parameters:

    annotator_id (str, required): The annotator's anonymous ID.

Returns:

    bool: True if all tutorial samples answered or no tutorials configured.

Source code in pare/annotation/server.py
def is_tutorial_completed(self, annotator_id: str) -> bool:
    """Check if annotator has completed all tutorial samples.

    Args:
        annotator_id: The annotator's anonymous ID.

    Returns:
        True if all tutorial samples answered or no tutorials configured.
    """
    if not self._tutorial_samples:
        return True
    with self._lock:
        done = self._tutorial_annotations.get(annotator_id, {})
        return len(done) >= len(self._tutorial_samples)

record_annotation(sample_id, annotator_id, human_decision, gather_context_rationale=None)

Record a real (non-tutorial) annotation.

Parameters:

    sample_id (str, required): The sample being annotated.
    annotator_id (str, required): The annotator's anonymous ID.
    human_decision (TernaryDecision, required): The human's accept/reject/gather_context decision.
    gather_context_rationale (str | None, default None): Free-text rationale when decision is gather_context.

Returns:

    bool: True if recorded successfully.

Raises:

    ValueError: If sample not found or already annotated by this user.

Source code in pare/annotation/server.py
def record_annotation(
    self,
    sample_id: str,
    annotator_id: str,
    human_decision: TernaryDecision,
    gather_context_rationale: str | None = None,
) -> bool:
    """Record a real (non-tutorial) annotation.

    Args:
        sample_id: The sample being annotated.
        annotator_id: The annotator's anonymous ID.
        human_decision: The human's accept/reject/gather_context decision.
        gather_context_rationale: Free-text rationale when decision is gather_context.

    Returns:
        True if recorded successfully.

    Raises:
        ValueError: If sample not found or already annotated by this user.
    """
    sample = self.get_sample(sample_id)
    if not sample:
        raise ValueError(f"Sample not found: {sample_id}")

    with self._lock:
        if annotator_id in self._user_annotations and sample_id in self._user_annotations[annotator_id]:
            raise ValueError(f"User {annotator_id} already annotated sample {sample_id}")

        annotation = Annotation.create(
            sample_id=sample_id,
            annotator_id=annotator_id,
            human_decision=human_decision,
            gather_context_rationale=gather_context_rationale,
        )

        with open(self.annotations_file, "a") as f:
            f.write(annotation.to_csv_row())

        self._annotation_counts[sample_id] = self._annotation_counts.get(sample_id, 0) + 1

        if annotator_id not in self._user_annotations:
            self._user_annotations[annotator_id] = set()
        self._user_annotations[annotator_id].add(sample_id)

        logger.info(
            f"Recorded annotation: sample={sample_id[:20]}..., user={annotator_id[:8]}..., decision={human_decision}"
        )

    return True

record_tutorial_annotation(sample_id, annotator_id, human_decision, gather_context_rationale=None)

Record a tutorial annotation and return feedback.

Writes to tutorial_annotations.csv (not annotations.csv).

Parameters:

    sample_id (str, required): The tutorial sample being annotated.
    annotator_id (str, required): The annotator's anonymous ID.
    human_decision (TernaryDecision, required): The annotator's decision.
    gather_context_rationale (str | None, default None): Free-text rationale when decision is gather_context.

Returns:

    dict[str, Any]: Feedback dict with correct, correct_decision, and explanation.

Raises:

    ValueError: If sample not found in tutorial samples.

Source code in pare/annotation/server.py
def record_tutorial_annotation(
    self,
    sample_id: str,
    annotator_id: str,
    human_decision: TernaryDecision,
    gather_context_rationale: str | None = None,
) -> dict[str, Any]:
    """Record a tutorial annotation and return feedback.

    Writes to tutorial_annotations.csv (not annotations.csv).

    Args:
        sample_id: The tutorial sample being annotated.
        annotator_id: The annotator's anonymous ID.
        human_decision: The annotator's decision.
        gather_context_rationale: Free-text rationale when decision is gather_context.

    Returns:
        Feedback dict with correct, correct_decision, and explanation.

    Raises:
        ValueError: If sample not found in tutorial samples.
    """
    sample = self._tutorial_samples.get(sample_id)
    if not sample:
        raise ValueError(f"Tutorial sample not found: {sample_id}")

    annotation = Annotation.create(
        sample_id=sample_id,
        annotator_id=annotator_id,
        human_decision=human_decision,
        gather_context_rationale=gather_context_rationale,
    )

    with self._lock:
        with open(self._tutorial_annotations_file, "a") as f:
            f.write(annotation.to_csv_row())
        if annotator_id not in self._tutorial_annotations:
            self._tutorial_annotations[annotator_id] = {}
        self._tutorial_annotations[annotator_id][sample_id] = human_decision

    correct_decision = sample.correct_decision
    is_correct = human_decision == correct_decision if correct_decision is not None else None

    logger.info(
        f"Tutorial annotation: sample={sample_id[:20]}..., user={annotator_id[:8]}..., "
        f"decision={human_decision}, correct={is_correct}"
    )

    return {
        "correct": is_correct,
        "correct_decision": correct_decision,
        "explanation": sample.explanation or "",
    }

create_app(samples_file, annotations_file, annotators_per_sample=2)

Create the FastAPI application.

Parameters:

    samples_file (Path, required): Path to the samples parquet file.
    annotations_file (Path, required): Path to the annotations CSV file.
    annotators_per_sample (int, default 2): Number of annotations required per sample.

Returns:

    FastAPI: Configured FastAPI application.

Source code in pare/annotation/server.py
def create_app(samples_file: Path, annotations_file: Path, annotators_per_sample: int = 2) -> FastAPI:  # noqa: C901
    """Create the FastAPI application.

    Args:
        samples_file: Path to the samples parquet file.
        annotations_file: Path to the annotations CSV file.
        annotators_per_sample: Number of annotations required per sample.

    Returns:
        Configured FastAPI application.
    """
    app = FastAPI(title="PARE Annotation Interface")

    # Initialize server
    server = AnnotationServer(samples_file, annotations_file, annotators_per_sample)

    @app.get("/", response_class=HTMLResponse)
    async def index() -> FileResponse:
        """Serve the main annotation UI."""
        return FileResponse(TEMPLATES_DIR / "index.html")

    @app.get("/api/sample")
    async def get_sample(
        x_annotator_id: str = Header(None, alias="X-Annotator-ID"),
    ) -> SampleResponse | dict[str, Any]:
        """Get the next sample for annotation.

        Serves tutorial samples first. After all tutorials are completed,
        serves real annotation samples.
        """
        if not x_annotator_id:
            raise HTTPException(status_code=400, detail="X-Annotator-ID header required")

        # Serve tutorial samples first
        if not server.is_tutorial_completed(x_annotator_id):
            tutorial_sample = server.get_next_tutorial_sample(x_annotator_id)
            if tutorial_sample:
                with server._lock:
                    tutorial_done = len(server._tutorial_annotations.get(x_annotator_id, {}))
                tutorial_total = len(server._tutorial_samples)
                return tutorial_sample.to_api_response(tutorial_done, tutorial_total)

        # Serve real samples
        sample = server.get_next_sample(x_annotator_id)
        progress = server.get_progress(x_annotator_id)

        if not sample:
            return {
                "sample_id": None,
                "message": "You have completed all available samples. Thank you!",
                "progress": progress,
            }

        return sample.to_api_response(progress["completed"], progress["total"])

    @app.post("/api/annotate")
    async def submit_annotation(
        request: AnnotationRequest,
        x_annotator_id: str = Header(None, alias="X-Annotator-ID"),
    ) -> dict[str, Any]:
        """Submit an annotation.

        Handles both tutorial and real samples. Tutorial submissions
        return feedback and write to tutorial_annotations.csv.
        """
        if not x_annotator_id:
            raise HTTPException(status_code=400, detail="X-Annotator-ID header required")

        # Check if this is a tutorial sample
        sample = server.get_sample(request.sample_id)
        if sample and sample.tutorial:
            try:
                feedback = server.record_tutorial_annotation(
                    sample_id=request.sample_id,
                    annotator_id=x_annotator_id,
                    human_decision=request.decision,
                    gather_context_rationale=request.gather_context_rationale,
                )
            except ValueError as e:
                raise HTTPException(status_code=400, detail=str(e)) from e

            # Check if tutorial is now complete
            if server.is_tutorial_completed(x_annotator_id):
                summary = server.get_tutorial_summary(x_annotator_id)
                return {
                    "success": True,
                    "tutorial_feedback": feedback,
                    "tutorial_complete": True,
                    "tutorial_summary": summary,
                    "next_sample": None,
                }

            # Get next tutorial sample
            next_tutorial = server.get_next_tutorial_sample(x_annotator_id)
            next_sample_data = None
            if next_tutorial:
                with server._lock:
                    tutorial_done = len(server._tutorial_annotations.get(x_annotator_id, {}))
                tutorial_total = len(server._tutorial_samples)
                next_sample_data = next_tutorial.to_api_response(
                    tutorial_done,
                    tutorial_total,
                ).model_dump()

            return {
                "success": True,
                "tutorial_feedback": feedback,
                "next_sample": next_sample_data,
            }

        # Real annotation
        try:
            server.record_annotation(
                sample_id=request.sample_id,
                annotator_id=x_annotator_id,
                human_decision=request.decision,
                gather_context_rationale=request.gather_context_rationale,
            )
        except ValueError as e:
            raise HTTPException(status_code=400, detail=str(e)) from e

        next_sample = server.get_next_sample(x_annotator_id)
        progress = server.get_progress(x_annotator_id)

        if not next_sample:
            return {
                "success": True,
                "next_sample": None,
                "message": "You have completed all available samples. Thank you!",
                "progress": progress,
            }

        return {
            "success": True,
            "next_sample": next_sample.to_api_response(progress["completed"], progress["total"]).model_dump(),
        }

    @app.get("/api/progress")
    async def get_progress_endpoint(
        x_annotator_id: str = Header(None, alias="X-Annotator-ID"),
    ) -> dict[str, Any]:
        """Get annotator's progress."""
        if not x_annotator_id:
            raise HTTPException(status_code=400, detail="X-Annotator-ID header required")

        progress = server.get_progress(x_annotator_id)
        total = progress["total"]
        completed = progress["completed"]
        percentage = (completed / total * 100) if total > 0 else 0

        return {
            "completed": completed,
            "total": total,
            "percentage": round(percentage, 1),
        }

    @app.get("/api/stats")
    async def get_stats() -> dict[str, Any]:
        """Get overall annotation statistics."""
        return server.get_overall_stats()

    return app
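
An in-process exercise of the app using FastAPI's TestClient; the file paths and the X-Annotator-ID value are placeholders, and "accept" assumes TernaryDecision admits that literal:

from pathlib import Path

from fastapi.testclient import TestClient

app = create_app(Path("annotation/samples.parquet"), Path("annotation/annotations.csv"))
client = TestClient(app)

headers = {"X-Annotator-ID": "annotator-123"}
sample = client.get("/api/sample", headers=headers).json()
if sample["sample_id"] is not None:
    resp = client.post(
        "/api/annotate",
        headers=headers,
        json={"sample_id": sample["sample_id"], "decision": "accept"},
    )
    print(resp.json()["success"])  # True on first submission
print(client.get("/api/stats").json())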

run_server(samples_file, annotations_file, port=8000, annotators_per_sample=2)

Run the annotation server.

Parameters:

    samples_file (Path, required): Path to the samples parquet file.
    annotations_file (Path, required): Path to the annotations CSV file.
    port (int, default 8000): Port to run the server on.
    annotators_per_sample (int, default 2): Number of annotations required per sample.
Source code in pare/annotation/server.py
def run_server(samples_file: Path, annotations_file: Path, port: int = 8000, annotators_per_sample: int = 2) -> None:
    """Run the annotation server.

    Args:
        samples_file: Path to the samples parquet file.
        annotations_file: Path to the annotations CSV file.
        port: Port to run the server on.
        annotators_per_sample: Number of annotations required per sample.
    """
    import uvicorn

    app = create_app(samples_file, annotations_file, annotators_per_sample)
    logger.info(f"Starting annotation server on http://localhost:{port}")
    logger.info(f"Samples: {samples_file}")
    logger.info(f"Annotations: {annotations_file}")

    uvicorn.run(app, host="0.0.0.0", port=port, log_level="info")  # noqa: S104

Metrics and Configuration

Agreement metrics computation for annotation analysis.

Implements comprehensive metrics for measuring alignment between an ML model and multiple human annotators on binary prediction tasks.

Metrics included:

1. Agreement with Majority Vote (accuracy, F1, precision, recall, Cohen's kappa)
2. Soft Label Alignment (cross-entropy, MAE)
3. Average Pairwise Cohen's Kappa (model vs each human)
4. Krippendorff's Alpha (model as k+1 rater)
5. Fleiss' Kappa (human-only baseline)
6. Stratified Analysis by consensus level

argmax_with_tiebreak(counts, labels, seed_str)

Argmax with deterministic tie-breaking using hash of seed string.

Parameters:

    counts (list[int], required): Vote counts per category.
    labels (list[str], required): Category labels corresponding to counts.
    seed_str (str, required): String to use as tie-breaking seed (e.g., sample_id).

Returns:

    str: Label with highest count (ties broken by hash).

Source code in pare/annotation/metrics.py
def argmax_with_tiebreak(counts: list[int], labels: list[str], seed_str: str) -> str:
    """Argmax with deterministic tie-breaking using hash of seed string.

    Args:
        counts: Vote counts per category.
        labels: Category labels corresponding to counts.
        seed_str: String to use as tie-breaking seed (e.g., sample_id).

    Returns:
        Label with highest count (ties broken by hash).
    """
    import hashlib

    max_count = max(counts)
    tied_labels = [labels[i] for i, c in enumerate(counts) if c == max_count]

    if len(tied_labels) == 1:
        return tied_labels[0]

    # Tie-break: hash seed_str to get deterministic index
    hash_val = int(hashlib.sha256(seed_str.encode()).hexdigest(), 16)
    return tied_labels[hash_val % len(tied_labels)]
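
With a unique maximum the hash never comes into play; with a tie, the SHA-256 hash of the seed string picks the index, so the same sample_id always resolves the same way (assuming the function is importable from pare.annotation.metrics):

from pare.annotation.metrics import argmax_with_tiebreak

labels = ["accept", "reject", "gather_context"]

print(argmax_with_tiebreak([2, 1, 0], labels, "sample-1"))  # accept (unique max)

# A 1-1-1 tie: hashlib (unlike Python's salted built-in hash()) is stable
# across runs and processes, so the same seed always yields the same label.
print(argmax_with_tiebreak([1, 1, 1], labels, "sample-1"))
print(argmax_with_tiebreak([1, 1, 1], labels, "sample-2"))  # may differ from sample-1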

cohens_kappa_multiclass(y1, y2)

Compute Cohen's Kappa between two raters for multiclass labels.

Parameters:

    y1 (list[str], required): First rater's decisions (categorical labels).
    y2 (list[str], required): Second rater's decisions (categorical labels).

Returns:

    float | None: Cohen's Kappa value or None if computation not possible.

Source code in pare/annotation/metrics.py
def cohens_kappa_multiclass(y1: list[str], y2: list[str]) -> float | None:
    """Compute Cohen's Kappa between two raters for multiclass labels.

    Args:
        y1: First rater's decisions (categorical labels).
        y2: Second rater's decisions (categorical labels).

    Returns:
        Cohen's Kappa value or None if computation not possible.
    """
    if len(y1) != len(y2) or len(y1) == 0:
        return None

    n = len(y1)
    categories = sorted(set(y1) | set(y2))

    # Build confusion matrix
    confusion: dict[tuple[str, str], int] = {(c1, c2): 0 for c1 in categories for c2 in categories}
    for a, b in zip(y1, y2, strict=False):
        confusion[(a, b)] += 1

    # Observed agreement
    p_o = sum(confusion[(c, c)] for c in categories) / n

    # Expected agreement by chance
    p_e = 0.0
    for c in categories:
        p_c1 = sum(confusion[(c, c2)] for c2 in categories) / n
        p_c2 = sum(confusion[(c1, c)] for c1 in categories) / n
        p_e += p_c1 * p_c2

    if p_e == 1:
        return 1.0 if p_o == 1 else 0.0

    kappa = (p_o - p_e) / (1 - p_e)
    return kappa
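
A small worked example: two raters who agree on 4 of 5 items. Observed agreement is p_o = 0.8; from the marginal label frequencies, chance agreement is p_e = 0.4*0.4 + 0.4*0.2 + 0.2*0.4 = 0.32, so kappa = (0.8 - 0.32) / (1 - 0.32) ≈ 0.706:

from pare.annotation.metrics import cohens_kappa_multiclass

y1 = ["accept", "accept", "reject", "reject", "gather_context"]
y2 = ["accept", "accept", "reject", "gather_context", "gather_context"]

# Raters agree on items 0, 1, 2, and 4.
print(round(cohens_kappa_multiclass(y1, y2), 3))  # 0.706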

compute_agreement_metrics(samples_df, annotations_df, n_annotators=2)

Compute comprehensive agreement metrics.

Deprecated: Use compute_agreement_metrics_ternary for ternary decisions. Will be removed after UI update.

Parameters:

    samples_df (DataFrame, required): DataFrame with sample data including user_agent_decision.
    annotations_df (DataFrame, required): DataFrame with annotation data.
    n_annotators (int, default 2): Number of top annotators to include (by completion count).

Returns:

    dict[str, Any]: Dictionary containing all computed metrics.

Source code in pare/annotation/metrics.py
def compute_agreement_metrics(
    samples_df: pl.DataFrame,
    annotations_df: pl.DataFrame,
    n_annotators: int = 2,
) -> dict[str, Any]:
    """Compute comprehensive agreement metrics.

    .. deprecated::
        Use ``compute_agreement_metrics_ternary`` for ternary decisions. Will be removed after UI update.

    Args:
        samples_df: DataFrame with sample data including user_agent_decision.
        annotations_df: DataFrame with annotation data.
        n_annotators: Number of top annotators to include (by completion count).

    Returns:
        Dictionary containing all computed metrics.
    """
    # Select top-n annotators by completion count
    filtered_annotations = _select_top_annotators(annotations_df, n_annotators)

    # Join annotations with samples
    joined = filtered_annotations.join(
        samples_df.select(["sample_id", "user_agent_decision"]),
        on="sample_id",
        how="left",
    )

    # Filter to samples that have annotations from ALL selected annotators
    selected_annotators = filtered_annotations["annotator_id"].unique().to_list()
    sample_annotator_counts = joined.group_by("sample_id").agg(pl.col("annotator_id").n_unique().alias("n_annotators"))
    complete_samples = sample_annotator_counts.filter(pl.col("n_annotators") == len(selected_annotators))[
        "sample_id"
    ].to_list()

    filtered = joined.filter(pl.col("sample_id").is_in(complete_samples))

    # Basic counts
    n_samples = len(complete_samples)
    n_annotations = len(filtered)
    actual_n_annotators = filtered["annotator_id"].n_unique() if len(filtered) > 0 else 0

    if n_samples == 0:
        logger.warning("No samples with complete annotations from selected annotators")
        return _empty_metrics()

    # Human-human agreement (baseline)
    fleiss_kappa_humans = _compute_fleiss_kappa(filtered, include_model=False)

    # Model alignment metrics
    majority_vote_metrics = _compute_majority_vote_metrics(filtered)
    soft_label_metrics = _compute_soft_label_metrics(filtered)
    avg_pairwise_kappa = _compute_avg_pairwise_model_human_kappa(filtered)
    krippendorff_alpha = _compute_krippendorff_alpha(filtered, include_model=True)
    fleiss_kappa_with_model = _compute_fleiss_kappa(filtered, include_model=True)

    # Stratified analysis
    stratified = _compute_stratified_analysis(filtered)

    # Distribution stats
    human_accept_rate = filtered["human_decision"].mean() if len(filtered) > 0 else 0
    agent_accept_rate = (
        samples_df.filter(pl.col("sample_id").is_in(complete_samples))["user_agent_decision"].mean()
        if n_samples > 0
        else 0
    )

    # Per-annotator stats
    per_annotator = _compute_per_annotator_stats(filtered)

    return {
        # Basic counts
        "n_samples": n_samples,
        "n_annotations": n_annotations,
        "n_annotators": actual_n_annotators,
        # Human-human agreement (baseline)
        "fleiss_kappa_humans": fleiss_kappa_humans,
        # Model vs majority vote
        "majority_vote_metrics": majority_vote_metrics,
        # Soft label alignment
        "soft_label_metrics": soft_label_metrics,
        # Average pairwise model-human kappa
        "avg_pairwise_model_human_kappa": avg_pairwise_kappa,
        # Krippendorff's alpha (model as k+1 rater)
        "krippendorff_alpha_with_model": krippendorff_alpha,
        # Fleiss' kappa with model as k+1 rater
        "fleiss_kappa_with_model": fleiss_kappa_with_model,
        # Stratified analysis by consensus level
        "stratified_analysis": stratified,
        # Distribution stats
        "human_accept_rate": human_accept_rate,
        "agent_accept_rate": agent_accept_rate,
        # Per-annotator stats
        "per_annotator_stats": per_annotator,
    }

compute_agreement_metrics_ternary(samples_df, annotations_df, n_annotators=2)

Compute comprehensive agreement metrics for ternary decisions.

Parameters:

    samples_df (DataFrame, required): DataFrame with sample data including user_agent_decision (str).
    annotations_df (DataFrame, required): DataFrame with annotation data (ternary decisions).
    n_annotators (int, default 2): Number of top annotators to include (by completion count).

Returns:

    dict[str, Any]: Dictionary containing all computed metrics.

Source code in pare/annotation/metrics.py
def compute_agreement_metrics_ternary(
    samples_df: pl.DataFrame,
    annotations_df: pl.DataFrame,
    n_annotators: int = 2,
) -> dict[str, Any]:
    """Compute comprehensive agreement metrics for ternary decisions.

    Args:
        samples_df: DataFrame with sample data including user_agent_decision (str).
        annotations_df: DataFrame with annotation data (ternary decisions).
        n_annotators: Number of top annotators to include (by completion count).

    Returns:
        Dictionary containing all computed metrics.
    """
    # Select top-n annotators by completion count
    filtered_annotations = _select_top_annotators(annotations_df, n_annotators)

    # Join annotations with samples
    joined = filtered_annotations.join(
        samples_df.select(["sample_id", "user_agent_decision"]),
        on="sample_id",
        how="left",
    )

    # Filter to samples that have annotations from ALL selected annotators
    selected_annotators = filtered_annotations["annotator_id"].unique().to_list()
    sample_annotator_counts = joined.group_by("sample_id").agg(pl.col("annotator_id").n_unique().alias("n_annotators"))
    complete_samples = sample_annotator_counts.filter(pl.col("n_annotators") == len(selected_annotators))[
        "sample_id"
    ].to_list()

    filtered = joined.filter(pl.col("sample_id").is_in(complete_samples))

    # Basic counts
    n_samples = len(complete_samples)
    n_annotations = len(filtered)
    actual_n_annotators = filtered["annotator_id"].n_unique() if len(filtered) > 0 else 0

    if n_samples == 0:
        logger.warning("No samples with complete annotations from selected annotators")
        return _empty_metrics_ternary()

    # Human-human agreement (baseline)
    fleiss_kappa_humans = compute_fleiss_kappa_multiclass(filtered, include_model=False)

    # Model alignment metrics
    majority_vote_metrics = _compute_majority_vote_metrics_ternary(filtered)
    avg_pairwise_kappa = _compute_avg_pairwise_model_human_kappa_ternary(filtered)
    krippendorff_alpha = compute_krippendorff_alpha_multiclass(filtered, include_model=True)
    fleiss_kappa_with_model = compute_fleiss_kappa_multiclass(filtered, include_model=True)

    # Distribution stats (per-category rates)
    category_rates = _compute_category_rates_ternary(filtered, samples_df, complete_samples)

    # Per-annotator stats
    per_annotator = _compute_per_annotator_stats_ternary(filtered)

    return {
        # Basic counts
        "n_samples": n_samples,
        "n_annotations": n_annotations,
        "n_annotators": actual_n_annotators,
        # Human-human agreement (baseline)
        "fleiss_kappa_humans": fleiss_kappa_humans,
        # Model vs majority vote
        "majority_vote_metrics": majority_vote_metrics,
        # Average pairwise model-human kappa
        "avg_pairwise_model_human_kappa": avg_pairwise_kappa,
        # Krippendorff's alpha (model as k+1 rater)
        "krippendorff_alpha_with_model": krippendorff_alpha,
        # Fleiss' kappa with model as k+1 rater
        "fleiss_kappa_with_model": fleiss_kappa_with_model,
        # Per-category rates
        "category_rates": category_rates,
        # Per-annotator stats
        "per_annotator_stats": per_annotator,
    }
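
A toy invocation with two annotators covering the same two samples. The column names follow the joins in the function above, but real data carries more columns and the private helpers may depend on them, so treat this as a shape sketch only:

import polars as pl

from pare.annotation.metrics import compute_agreement_metrics_ternary

samples_df = pl.DataFrame({
    "sample_id": ["s1", "s2"],
    "user_agent_decision": ["accept", "reject"],  # model decision as a string (ternary schema)
})
annotations_df = pl.DataFrame({
    "sample_id": ["s1", "s1", "s2", "s2"],
    "annotator_id": ["ann-a", "ann-b", "ann-a", "ann-b"],
    "human_decision": ["accept", "accept", "reject", "gather_context"],
})

metrics = compute_agreement_metrics_ternary(samples_df, annotations_df, n_annotators=2)
print(metrics["n_samples"], metrics["fleiss_kappa_humans"])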

compute_decision_entropy(soft_labels_df)

Compute average entropy per user model across all samples.

Entropy measures decision consistency: low entropy = consistent decisions, high entropy = uncertain/variable decisions.

Parameters:

    soft_labels_df (DataFrame, required): DataFrame from compute_soft_labels_ternary.

Returns:

    dict[str, float]: Dictionary mapping user_model_id to average entropy.

Source code in pare/annotation/metrics.py
def compute_decision_entropy(soft_labels_df: pl.DataFrame) -> dict[str, float]:
    """Compute average entropy per user model across all samples.

    Entropy measures decision consistency: low entropy = consistent decisions,
    high entropy = uncertain/variable decisions.

    Args:
        soft_labels_df: DataFrame from compute_soft_labels_ternary.

    Returns:
        Dictionary mapping user_model_id to average entropy.
    """
    # Compute entropy for each sample
    soft_labels_df = soft_labels_df.with_columns([
        pl.struct(["accept_prob", "reject_prob", "gather_context_prob"])
        .map_elements(
            lambda row: -sum(
                p * math.log(p) if p > 0 else 0
                for p in [row["accept_prob"], row["reject_prob"], row["gather_context_prob"]]
            ),
            return_dtype=pl.Float64,
        )
        .alias("entropy")
    ])

    # Average per model
    model_entropies = soft_labels_df.group_by("user_model_id").agg(pl.col("entropy").mean().alias("avg_entropy"))

    return dict(zip(model_entropies["user_model_id"].to_list(), model_entropies["avg_entropy"].to_list(), strict=False))
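
A quick worked example; the import path is assumed from the source location shown above:

```python
import polars as pl

from pare.annotation.metrics import compute_decision_entropy  # assumed import path

soft_labels_df = pl.DataFrame({
    "sample_id": ["s1", "s2"],
    "user_model_id": ["model-a", "model-a"],
    "accept_prob": [1.0, 1 / 3],
    "reject_prob": [0.0, 1 / 3],
    "gather_context_prob": [0.0, 1 / 3],
})

entropies = compute_decision_entropy(soft_labels_df)
# s1 is fully consistent (entropy 0); s2 is maximally uncertain
# (entropy ln 3 ≈ 1.099), so the average for "model-a" is ≈ 0.549.
```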

compute_fleiss_kappa_multiclass(df, *, include_model=False)

Compute Fleiss' Kappa for multiple raters with arbitrary categories.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| df | DataFrame | DataFrame with annotations. | required |
| include_model | bool | If True, treat model as an additional rater. | False |

Returns:

| Type | Description |
| --- | --- |
| float \| None | Fleiss' Kappa value or None if computation not possible. |

Source code in pare/annotation/metrics.py
def compute_fleiss_kappa_multiclass(df: pl.DataFrame, *, include_model: bool = False) -> float | None:  # noqa: C901
    """Compute Fleiss' Kappa for multiple raters with arbitrary categories.

    Args:
        df: DataFrame with annotations.
        include_model: If True, treat model as an additional rater.

    Returns:
        Fleiss' Kappa value or None if computation not possible.
    """
    if len(df) == 0:
        return None

    samples = df["sample_id"].unique().to_list()
    annotators = df["annotator_id"].unique().to_list()

    if len(annotators) < 2 and not include_model:
        return None

    # Determine all categories from the data
    categories = sorted(df["human_decision"].unique().to_list())
    if include_model:
        categories = sorted(set(categories) | set(df["user_agent_decision"].unique().to_list()))

    # Build rating matrix: rows = samples, columns = categories
    ratings_matrix = []
    for sample_id in samples:
        sample_annotations = df.filter(pl.col("sample_id") == sample_id)
        counts = dict.fromkeys(categories, 0)

        # Count human annotations
        for decision in sample_annotations["human_decision"].to_list():
            counts[decision] += 1

        # Add model decision if requested
        if include_model:
            agent_decision = sample_annotations["user_agent_decision"].first()
            if agent_decision in counts:
                counts[agent_decision] += 1

        ratings_matrix.append([counts[cat] for cat in categories])

    if not ratings_matrix:
        return None

    # Calculate P_i for each sample (proportion of agreeing pairs)
    p_i_list = []
    for row in ratings_matrix:
        n_j = sum(row)
        if n_j <= 1:
            continue
        p_i = (sum(r * r for r in row) - n_j) / (n_j * (n_j - 1))
        p_i_list.append(p_i)

    if not p_i_list:
        return None

    p_bar = sum(p_i_list) / len(p_i_list)

    # Calculate P_j for each category (proportion of all ratings in that category)
    total_ratings = sum(sum(row) for row in ratings_matrix)
    p_j_list = []
    for j in range(len(categories)):
        category_total = sum(row[j] for row in ratings_matrix)
        p_j_list.append(category_total / total_ratings if total_ratings > 0 else 0)

    # Expected agreement by chance
    p_e = sum(p_j * p_j for p_j in p_j_list)

    if p_e == 1:
        return 1.0 if p_bar == 1 else 0.0

    kappa = (p_bar - p_e) / (1 - p_e)
    return kappa
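
A small worked example (import path assumed from the source location above): two annotators agree on one sample and split on the other, giving observed agreement 0.5 against a chance agreement of 0.625:

```python
import polars as pl

from pare.annotation.metrics import compute_fleiss_kappa_multiclass  # assumed import path

df = pl.DataFrame({
    "sample_id": ["s1", "s1", "s2", "s2"],
    "annotator_id": ["a1", "a2", "a1", "a2"],
    "human_decision": ["accept", "accept", "reject", "accept"],
})

kappa = compute_fleiss_kappa_multiclass(df, include_model=False)
# P_bar = (1 + 0) / 2 = 0.5, P_e = 0.75**2 + 0.25**2 = 0.625,
# so kappa = (0.5 - 0.625) / (1 - 0.625) = -1/3.
```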

compute_kl_divergence(model_probs, human_probs)

Compute KL divergence between model and human probability distributions.

KL(human || model) measures how much information is lost when the model distribution is used to approximate the human distribution.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model_probs | list[float] | Model probability distribution [p_accept, p_reject, p_gather]. | required |
| human_probs | list[float] | Human probability distribution [p_accept, p_reject, p_gather]. | required |

Returns:

| Type | Description |
| --- | --- |
| float | KL divergence value (non-negative, 0 means identical distributions). |

Source code in pare/annotation/metrics.py
def compute_kl_divergence(model_probs: list[float], human_probs: list[float]) -> float:
    """Compute KL divergence between model and human probability distributions.

    KL(human || model) measures how much information is lost when using model
    distribution to approximate human distribution.

    Args:
        model_probs: Model probability distribution [p_accept, p_reject, p_gather].
        human_probs: Human probability distribution [p_accept, p_reject, p_gather].

    Returns:
        KL divergence value (non-negative, 0 means identical distributions).
    """
    eps = LOG_EPSILON
    kl = 0.0
    for p_human, p_model in zip(human_probs, model_probs, strict=False):
        if p_human > 0:
            p_model_clipped = max(eps, p_model)
            kl += p_human * math.log(p_human / p_model_clipped)
    return kl
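
For instance, with matching supports the divergence reduces to the two non-zero terms (import path assumed as above):

```python
from pare.annotation.metrics import compute_kl_divergence  # assumed import path

model_probs = [0.25, 0.75, 0.0]
human_probs = [0.5, 0.5, 0.0]

kl = compute_kl_divergence(model_probs, human_probs)
# 0.5 * ln(0.5 / 0.25) + 0.5 * ln(0.5 / 0.75) ≈ 0.144
```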

compute_krippendorff_alpha_multiclass(df, *, include_model=True)

Compute Krippendorff's Alpha for nominal data with arbitrary categories.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| df | DataFrame | DataFrame with annotations. | required |
| include_model | bool | If True, treat model as an additional rater. | True |

Returns:

| Type | Description |
| --- | --- |
| float \| None | Krippendorff's Alpha value or None if computation not possible. |

Source code in pare/annotation/metrics.py
def compute_krippendorff_alpha_multiclass(df: pl.DataFrame, *, include_model: bool = True) -> float | None:  # noqa: C901
    """Compute Krippendorff's Alpha for nominal data with arbitrary categories.

    Args:
        df: DataFrame with annotations.
        include_model: If True, treat model as an additional rater.

    Returns:
        Krippendorff's Alpha value or None if computation not possible.
    """
    if len(df) == 0:
        return None

    samples = df["sample_id"].unique().to_list()
    annotators = df["annotator_id"].unique().to_list()

    # Determine all categories
    categories = sorted(df["human_decision"].unique().to_list())
    if include_model:
        categories = sorted(set(categories) | set(df["user_agent_decision"].unique().to_list()))

    # Build reliability data: dict of (rater, sample) -> value
    ratings: dict[tuple[str, str], str] = {}
    for row in df.iter_rows(named=True):
        ratings[(row["annotator_id"], row["sample_id"])] = row["human_decision"]

    # Add model as additional rater if requested
    if include_model:
        model_id = "__MODEL__"
        for row in df.iter_rows(named=True):
            ratings[(model_id, row["sample_id"])] = row["user_agent_decision"]
        all_raters = [*annotators, model_id]
    else:
        all_raters = annotators

    if len(all_raters) < 2:
        return None

    # Build observed coincidence matrix
    o_matrix = {(v1, v2): 0.0 for v1 in categories for v2 in categories}

    for sample_id in samples:
        # Get all ratings for this sample
        sample_ratings = [(r, ratings[(r, sample_id)]) for r in all_raters if (r, sample_id) in ratings]

        n_raters = len(sample_ratings)
        if n_raters < 2:
            continue

        # For each pair of raters, count coincidences
        for i, (_, v1) in enumerate(sample_ratings):
            for j, (_, v2) in enumerate(sample_ratings):
                if i != j:
                    o_matrix[(v1, v2)] += 1.0 / (n_raters - 1)

    # Total number of pairable values
    n_total = sum(o_matrix.values())
    if n_total == 0:
        return None

    # Calculate observed disagreement (for nominal data: disagreement = 1 if different, 0 if same)
    d_o = 0.0
    for v1 in categories:
        for v2 in categories:
            if v1 != v2:
                d_o += o_matrix[(v1, v2)]
    d_o /= n_total

    # Calculate expected disagreement
    n_c = {v: sum(o_matrix[(v, v2)] for v2 in categories) for v in categories}
    n_total_marginal = sum(n_c.values())

    if n_total_marginal <= 1:
        return None

    d_e = 0.0
    for v1 in categories:
        for v2 in categories:
            if v1 != v2:
                d_e += n_c[v1] * n_c[v2]
    d_e /= n_total_marginal * (n_total_marginal - 1)

    if d_e == 0:
        return 1.0 if d_o == 0 else 0.0

    alpha = 1 - d_o / d_e
    return alpha
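
Usage mirrors the Fleiss' kappa helper, except the model is included by default; a hedged sketch (import path assumed as above):

```python
import polars as pl

from pare.annotation.metrics import compute_krippendorff_alpha_multiclass  # assumed import path

df = pl.DataFrame({
    "sample_id": ["s1", "s1", "s2", "s2"],
    "annotator_id": ["a1", "a2", "a1", "a2"],
    "human_decision": ["accept", "accept", "reject", "gather_context"],
    "user_agent_decision": ["accept", "accept", "reject", "reject"],
})

# Three raters per sample: a1, a2, and the model (added as "__MODEL__").
alpha = compute_krippendorff_alpha_multiclass(df, include_model=True)
```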

compute_per_model_agreement_metrics(evaluations_df, annotations_df, n_annotators=2)

Compute agreement metrics per user model from the evaluation dataframe.

Deprecated: use compute_per_model_agreement_metrics_ternary for ternary decisions. This function will be removed after the UI update.

For each user_model_id in the evaluations, aggregates runs via majority vote and computes agreement metrics against human annotations.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| evaluations_df | DataFrame | DataFrame from pare annotation evaluate with columns: sample_id, user_model_id, user_agent_decision, run, valid_response. | required |
| annotations_df | DataFrame | DataFrame with human annotation data. | required |
| n_annotators | int | Number of top annotators to include (by completion count). | 2 |

Returns:

| Type | Description |
| --- | --- |
| dict[str, dict[str, Any]] | Dictionary mapping user_model_id to agreement metrics dict. |

Source code in pare/annotation/metrics.py
def compute_per_model_agreement_metrics(
    evaluations_df: pl.DataFrame,
    annotations_df: pl.DataFrame,
    n_annotators: int = 2,
) -> dict[str, dict[str, Any]]:
    """Compute agreement metrics per user model from the evaluation dataframe.

    .. deprecated::
        Use ``compute_per_model_agreement_metrics_ternary`` for ternary decisions. Will be removed after UI update.

    For each user_model_id in the evaluations, aggregates runs via majority vote
    and computes agreement metrics against human annotations.

    Args:
        evaluations_df: DataFrame from `pare annotation evaluate` with columns:
            sample_id, user_model_id, user_agent_decision, run, valid_response.
        annotations_df: DataFrame with human annotation data.
        n_annotators: Number of top annotators to include (by completion count).

    Returns:
        Dictionary mapping user_model_id to agreement metrics dict.
    """
    user_models = evaluations_df["user_model_id"].unique().to_list()
    results: dict[str, dict[str, Any]] = {}

    for model_id in sorted(user_models):
        model_evals = evaluations_df.filter((pl.col("user_model_id") == model_id) & pl.col("valid_response"))

        # Majority vote across runs for each sample
        majority_votes = model_evals.group_by("sample_id").agg(
            (pl.col("user_agent_decision").mean() >= 0.5).alias("user_agent_decision"),
            pl.col("scenario_id").first().alias("scenario_id"),
        )

        if len(majority_votes) == 0:
            logger.warning(f"No valid evaluations for model {model_id}")
            results[model_id] = _empty_metrics()
            continue

        # Use the existing compute_agreement_metrics with majority-voted decisions as "samples"
        results[model_id] = compute_agreement_metrics(majority_votes, annotations_df, n_annotators)

    return results

compute_per_model_agreement_metrics_ternary(evaluations_df, annotations_df, n_annotators=2)

Compute agreement metrics per user model for ternary decisions.

For each user_model_id in the evaluations, aggregates runs via majority vote (argmax with deterministic tie-breaking) and computes agreement metrics against human annotations.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| evaluations_df | DataFrame | DataFrame with columns: sample_id, user_model_id, user_agent_decision (str), run, valid_response. | required |
| annotations_df | DataFrame | DataFrame with human annotation data (ternary decisions). | required |
| n_annotators | int | Number of top annotators to include (by completion count). | 2 |

Returns:

| Type | Description |
| --- | --- |
| dict[str, dict[str, Any]] | Dictionary mapping user_model_id to agreement metrics dict. |

Source code in pare/annotation/metrics.py
def compute_per_model_agreement_metrics_ternary(
    evaluations_df: pl.DataFrame,
    annotations_df: pl.DataFrame,
    n_annotators: int = 2,
) -> dict[str, dict[str, Any]]:
    """Compute agreement metrics per user model for ternary decisions.

    For each user_model_id in the evaluations, aggregates runs via majority vote
    (argmax with deterministic tie-breaking) and computes agreement metrics
    against human annotations.

    Args:
        evaluations_df: DataFrame with columns:
            sample_id, user_model_id, user_agent_decision (str), run, valid_response.
        annotations_df: DataFrame with human annotation data (ternary decisions).
        n_annotators: Number of top annotators to include (by completion count).

    Returns:
        Dictionary mapping user_model_id to agreement metrics dict.
    """
    user_models = evaluations_df["user_model_id"].unique().to_list()
    results: dict[str, dict[str, Any]] = {}

    for model_id in sorted(user_models):
        model_evals = evaluations_df.filter((pl.col("user_model_id") == model_id) & pl.col("valid_response"))

        # Count votes per category for each sample
        vote_counts = model_evals.group_by("sample_id").agg([
            (pl.col("user_agent_decision") == "accept").sum().alias("accept_count"),
            (pl.col("user_agent_decision") == "reject").sum().alias("reject_count"),
            (pl.col("user_agent_decision") == "gather_context").sum().alias("gather_context_count"),
            pl.col("scenario_id").first().alias("scenario_id"),
        ])

        if len(vote_counts) == 0:
            logger.warning(f"No valid evaluations for model {model_id}")
            results[model_id] = _empty_metrics_ternary()
            continue

        # Majority vote via argmax with deterministic tie-breaking
        majority_votes = vote_counts.with_columns([
            pl.struct(["sample_id", "accept_count", "reject_count", "gather_context_count"])
            .map_elements(
                lambda row: argmax_with_tiebreak(
                    [row["accept_count"], row["reject_count"], row["gather_context_count"]],
                    ["accept", "reject", "gather_context"],
                    row["sample_id"],
                ),
                return_dtype=pl.String,
            )
            .alias("user_agent_decision")
        ])

        results[model_id] = compute_agreement_metrics_ternary(
            majority_votes.select(["sample_id", "scenario_id", "user_agent_decision"]),
            annotations_df,
            n_annotators,
        )

    return results
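
An end-to-end sketch; the column layouts follow the docstrings above and the Annotation model (sample_id, annotator_id, human_decision), and the import path is assumed:

```python
import polars as pl

from pare.annotation.metrics import compute_per_model_agreement_metrics_ternary  # assumed import path

evaluations_df = pl.DataFrame({
    "sample_id": ["s1", "s1", "s1"],
    "scenario_id": ["sc1", "sc1", "sc1"],
    "user_model_id": ["model-a", "model-a", "model-a"],
    "user_agent_decision": ["accept", "accept", "gather_context"],
    "run": [0, 1, 2],
    "valid_response": [True, True, True],
})

annotations_df = pl.DataFrame({
    "sample_id": ["s1", "s1"],
    "annotator_id": ["a1", "a2"],
    "human_decision": ["accept", "gather_context"],
})

# "accept" wins the majority vote for s1 (2 of 3 runs), then agreement
# against the two annotators' ternary labels is computed per model.
results = compute_per_model_agreement_metrics_ternary(evaluations_df, annotations_df, n_annotators=2)
```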

compute_soft_labels_ternary(eval_df)

Compute soft labels from ternary evaluation results.

Groups by (sample_id, user_model_id) and computes raw counts and probabilities for each decision category (accept, reject, gather_context).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| eval_df | DataFrame | DataFrame with columns sample_id, user_model_id, user_agent_decision, valid_response. | required |

Returns:

| Type | Description |
| --- | --- |
| DataFrame | DataFrame with columns: sample_id; user_model_id; accept_count, reject_count, gather_context_count; accept_prob, reject_prob, gather_context_prob. |

Source code in pare/annotation/metrics.py
def compute_soft_labels_ternary(eval_df: pl.DataFrame) -> pl.DataFrame:
    """Compute soft labels from ternary evaluation results.

    Groups by (sample_id, user_model_id) and computes raw counts and
    probabilities for each decision category (accept, reject, gather_context).

    Args:
        eval_df: DataFrame with columns sample_id, user_model_id, user_agent_decision, valid_response.

    Returns:
        DataFrame with columns:
            - sample_id
            - user_model_id
            - accept_count, reject_count, gather_context_count
            - accept_prob, reject_prob, gather_context_prob
    """
    # Filter to valid responses only
    valid = eval_df.filter(pl.col("valid_response"))

    # Count per category
    soft_labels = valid.group_by(["sample_id", "user_model_id"]).agg([
        (pl.col("user_agent_decision") == "accept").sum().alias("accept_count"),
        (pl.col("user_agent_decision") == "reject").sum().alias("reject_count"),
        (pl.col("user_agent_decision") == "gather_context").sum().alias("gather_context_count"),
    ])

    # Compute probabilities
    soft_labels = soft_labels.with_columns([
        (pl.col("accept_count") + pl.col("reject_count") + pl.col("gather_context_count")).alias("total"),
    ]).with_columns([
        (pl.col("accept_count").cast(pl.Float64) / pl.col("total")).alias("accept_prob"),
        (pl.col("reject_count").cast(pl.Float64) / pl.col("total")).alias("reject_prob"),
        (pl.col("gather_context_count").cast(pl.Float64) / pl.col("total")).alias("gather_context_prob"),
    ])

    return soft_labels.select([
        "sample_id",
        "user_model_id",
        "accept_count",
        "reject_count",
        "gather_context_count",
        "accept_prob",
        "reject_prob",
        "gather_context_prob",
    ])
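
A worked example (import path assumed as above); note that invalid responses are excluded before counting:

```python
import polars as pl

from pare.annotation.metrics import compute_soft_labels_ternary  # assumed import path

eval_df = pl.DataFrame({
    "sample_id": ["s1", "s1", "s1", "s1"],
    "user_model_id": ["model-a"] * 4,
    "user_agent_decision": ["accept", "accept", "reject", "accept"],
    "valid_response": [True, True, True, False],
})

soft = compute_soft_labels_ternary(eval_df)
# The invalid fourth run is dropped, leaving counts (2, 1, 0) and
# probabilities (2/3, 1/3, 0.0) for (s1, model-a).
```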

Configuration helpers for the annotation module.

ensure_extension(path, ext)

Ensure a file path has the correct extension.

Replaces the path's final extension (if any) with the specified one.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| path | Path | The file path. | required |
| ext | str | The desired extension (e.g., '.parquet', '.csv'). | required |

Returns:

| Type | Description |
| --- | --- |
| Path | Path with the correct extension. |

Source code in pare/annotation/config.py
def ensure_extension(path: Path, ext: str) -> Path:
    """Ensure a file path has the correct extension.

    Strips any existing extension and replaces with the specified one.

    Args:
        path: The file path.
        ext: The desired extension (e.g., '.parquet', '.csv').

    Returns:
        Path with the correct extension.
    """
    if not ext.startswith("."):
        ext = f".{ext}"
    return path.with_suffix(ext)
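
For example (Path.with_suffix does the actual replacement, so only the final suffix is affected):

```python
from pathlib import Path

from pare.annotation.config import ensure_extension  # import path matches the source location above

ensure_extension(Path("results/annotations.csv"), ".parquet")
# Path("results/annotations.parquet")

ensure_extension(Path("results/annotations"), "parquet")
# Path("results/annotations.parquet") -- a leading dot is added when missing
```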