mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-17 02:05:57 +00:00
fix(dispatch): forward session_id into registry.dispatch (#28479)
Both the regular and execute_code dispatch paths forward task_id into registry.dispatch via middleware _dispatch lambdas but silently dropped session_id. Dispatch-layer hooks (e.g. set_enforcement_fn) that correlate calls with the active session received "" for every invocation. Pass session_id=session_id at both _dispatch call sites inside handle_function_call, matching the existing task_id pattern. Hooks already received session_id; this closes the registry.dispatch gap. Rebased onto current main where dispatch is wrapped by run_tool_execution_middleware — the old direct-dispatch sites from #28479 no longer exist. test(dispatch): add tests for session_id forwarding (NousResearch#28479) Covers standard and execute_code paths through the middleware wrapper. Verifies task_id forwarding is not broken by the change.
This commit is contained in:
parent
7aaae7acd0
commit
8d5d36d793
2 changed files with 77 additions and 0 deletions
|
|
@ -1115,6 +1115,7 @@ def handle_function_call(
|
|||
return registry.dispatch(
|
||||
function_name, next_args,
|
||||
task_id=task_id,
|
||||
session_id=session_id,
|
||||
enabled_tools=sandbox_enabled,
|
||||
)
|
||||
else:
|
||||
|
|
@ -1122,6 +1123,7 @@ def handle_function_call(
|
|||
return registry.dispatch(
|
||||
function_name, next_args,
|
||||
task_id=task_id,
|
||||
session_id=session_id,
|
||||
user_task=user_task,
|
||||
)
|
||||
from hermes_cli.middleware import run_tool_execution_middleware
|
||||
|
|
|
|||
75
tests/test_dispatch_session_id.py
Normal file
75
tests/test_dispatch_session_id.py
Normal file
|
|
@ -0,0 +1,75 @@
|
|||
"""Tests that handle_function_call forwards session_id into registry.dispatch."""
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
def _make_registry(captured: dict):
|
||||
"""Return a mock registry whose dispatch records the kwargs it receives."""
|
||||
registry = MagicMock()
|
||||
|
||||
def _dispatch(name, args, **kwargs):
|
||||
captured.update(kwargs)
|
||||
return json.dumps({"result": "ok"})
|
||||
|
||||
registry.dispatch.side_effect = _dispatch
|
||||
return registry
|
||||
|
||||
|
||||
class TestSessionIdForwarding:
|
||||
|
||||
def test_standard_path_forwards_session_id(self):
|
||||
"""registry.dispatch receives session_id on the normal tool path."""
|
||||
captured = {}
|
||||
with patch("model_tools.registry", _make_registry(captured)):
|
||||
from model_tools import handle_function_call
|
||||
handle_function_call(
|
||||
"web_search",
|
||||
{"query": "test"},
|
||||
task_id="t1",
|
||||
session_id="sess-abc",
|
||||
skip_pre_tool_call_hook=True,
|
||||
)
|
||||
assert captured.get("session_id") == "sess-abc"
|
||||
|
||||
def test_execute_code_path_forwards_session_id(self):
|
||||
"""registry.dispatch receives session_id on the execute_code path."""
|
||||
captured = {}
|
||||
with patch("model_tools.registry", _make_registry(captured)):
|
||||
from model_tools import handle_function_call
|
||||
handle_function_call(
|
||||
"execute_code",
|
||||
{"code": "print(1)"},
|
||||
task_id="t1",
|
||||
session_id="sess-xyz",
|
||||
skip_pre_tool_call_hook=True,
|
||||
)
|
||||
assert captured.get("session_id") == "sess-xyz"
|
||||
|
||||
def test_session_id_default_is_none(self):
|
||||
"""When session_id is omitted, dispatch receives None."""
|
||||
captured = {}
|
||||
with patch("model_tools.registry", _make_registry(captured)):
|
||||
from model_tools import handle_function_call
|
||||
handle_function_call(
|
||||
"web_search",
|
||||
{"query": "test"},
|
||||
task_id="t1",
|
||||
skip_pre_tool_call_hook=True,
|
||||
)
|
||||
assert "session_id" in captured
|
||||
assert captured["session_id"] is None
|
||||
|
||||
def test_task_id_still_forwarded(self):
|
||||
"""Existing task_id forwarding is not broken by this change."""
|
||||
captured = {}
|
||||
with patch("model_tools.registry", _make_registry(captured)):
|
||||
from model_tools import handle_function_call
|
||||
handle_function_call(
|
||||
"web_search",
|
||||
{"query": "test"},
|
||||
task_id="task-999",
|
||||
session_id="sess-1",
|
||||
skip_pre_tool_call_hook=True,
|
||||
)
|
||||
assert captured.get("task_id") == "task-999"
|
||||
Loading…
Add table
Add a link
Reference in a new issue