Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Don't Merge] DSPy Debugger #1481

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions dsp/utils/settings.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import threading
import json
from contextlib import contextmanager

from dsp.utils.utils import dotdict
Expand Down Expand Up @@ -35,6 +36,7 @@ def __new__(cls):
compiling=False, # TODO: can probably be removed
skip_logprobs=False,
trace=[],
debug_trace =[],
release=0,
bypass_assert=False,
bypass_suggest=False,
Expand Down Expand Up @@ -99,6 +101,14 @@ def context(self, inherit_config=True, **kwargs):

def __repr__(self) -> str:
return repr(self.config)

def dump_info_for_field(self, field:str = "lm", ignores: list = ["history"]) -> str:
component = self.__getattr__(field)
if not component:
return "{}"
details = {attr: getattr(component, attr) for attr in dir(component)
if not attr.startswith('_') and not callable(getattr(component, attr)) and not attr in ignores}
return json.dumps(details)


settings = Settings()
190 changes: 187 additions & 3 deletions dspy/primitives/program.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
import magicattr
import os

from dspy.primitives.assertions import *
from dspy.primitives.module import BaseModule


# from dspy.teleprompt.teleprompt import Teleprompter
from pydantic import BaseModel, RootModel
from typing import List, Tuple, Set, Optional, TypeVar, Type, Union, Literal
import inspect
class ProgramMeta(type):
pass
def __new__(cls, name, bases, class_dict):
for attr, value in class_dict.items():
if attr == "forward" and callable(value):
original_method= value
class_dict[attr] = forward_wrapper(original_method,cls,name)
return type.__new__(cls, name, bases, class_dict)
# def __call__(cls, *args, **kwargs):
# obj = super(ProgramMeta, cls).__call__(*args, **kwargs)

Expand All @@ -14,6 +22,29 @@ class ProgramMeta(type):
# obj._program_init_called = True
# return obj

class PredictorDebugInfo(BaseModel):
demos : List[dict]
signature : dict
extended_signature : Optional[dict] = None
type : Literal["PredictorDebugInfo"] = "PredictorDebugInfo"
unique_id : int

class RetrieveDebugInfo(BaseModel):
k : int
type : Literal["RetrieveDebugInfo"] = "RetrieveDebugInfo"
unique_id : int

class ModuleDebugInfo(BaseModel):
unique_id : int
name : str
class_name : str
path : str
line_num : int
parameters: List[Tuple[str, Union[PredictorDebugInfo,RetrieveDebugInfo] ]]
invoked_modules : List[tuple[str, int]]

class ModelDebugInfoGraph(RootModel):
root : List[ModuleDebugInfo]

class Module(BaseModule, metaclass=ProgramMeta):
def _base_init(self):
Expand Down Expand Up @@ -54,7 +85,84 @@ def activate_assertions(self, handler=backtrack_handler, **handler_args):
"""
assert_transform_module(self, handler, **handler_args)
return self

def trace_info(self) ->str :
import json
from pydantic.json import pydantic_encoder
return json.dumps(dsp.settings.debug_trace)


def debug_info(module : BaseModule) -> str:
from dspy.predict.predict import Predict
from dspy.retrieve.retrieve import Retrieve, RetrieveThenRerank
from collections import deque
from collections.abc import Generator
from dspy.predict.parameter import Parameter
import itertools
T = TypeVar('T')
def named_direct_subobjs(obj, type_ : Type[T]) -> Generator[tuple[str, T], None, None]:
# this function is very similar to the named_sub_modules
# but is only at the base level and will not recursively go find
# inside another attribute
queue = deque([])
seen = set()
def add_to_queue(name, item):
if id(item) not in seen:
seen.add(id(item))
queue.append((name, item))
for name, item in obj.__dict__.items():
add_to_queue(f"{name}", item)

while queue:
name, item = queue.popleft()

if isinstance(item, type_):
yield name, item

elif isinstance(item, (list, tuple)):
for i, sub_item in enumerate(item):
add_to_queue(f"{name}[{i}]", sub_item)

elif isinstance(item, dict):
for key, sub_item in item.items():
add_to_queue(f"{name}[{key}]", sub_item)

ls = []
def debug_info_inner(module : BaseModule, module_sets: Set[int], name: str):
unique_id = id(module)
class_name = type(module).__name__
path = os.path.abspath(inspect.getfile(module.__class__))
line = inspect.findsource(module.__class__)[1]
module_sets.add(unique_id)
sub_modules = list(named_direct_subobjs(module, BaseModule))
non_predict_modules = filter(lambda mod: not isinstance(mod[1], Parameter), sub_modules)
submodule_info : List[tuple[str, int]] = []
for sub_module_name, sub_module in non_predict_modules:
if id(sub_modules) in module_sets:
continue
submodule_info.append((sub_module_name, id(sub_module)))
debug_info_inner(sub_module, module_sets, sub_module_name)
parameters = list(named_direct_subobjs(module, Parameter))
parameters_infos: List[Tuple[str, Union[PredictorDebugInfo,RetrieveDebugInfo]]] = []
if isinstance(module, Parameter):
parameters = itertools.chain([("self as predictor", module)], parameters)
for param_name, parameter in parameters:
unique_param_id = id(parameter)*10+1
if isinstance(parameter, Predict):
demos = list(map(lambda demo : demo.toDict(), parameter.demos))
signature = parameter.signature.model_json_schema()
extended_signature = parameter.extended_signature.model_json_schema() if hasattr(parameter, "extended_signature") else None
info = PredictorDebugInfo(demos = demos, signature=signature, extended_signature=extended_signature, unique_id=unique_param_id)
elif isinstance(parameter, Retrieve) or isinstance(parameter, RetrieveThenRerank):
k = parameter.k
info = RetrieveDebugInfo(k=k, unique_id=unique_param_id)
if info:
parameters_infos.append((param_name, info))
ls.append(ModuleDebugInfo(unique_id=unique_id, name=name, class_name=class_name,
path = path, line_num=line, parameters= parameters_infos,
invoked_modules = submodule_info))
debug_info_inner(module, set(),"current module")
return ModelDebugInfoGraph(ls).model_dump_json()
# def __deepcopy__(self, memo):
# # memo is a dict of id's to copies already made during the current call
# # Check if the object is already copied
Expand All @@ -72,6 +180,82 @@ def activate_assertions(self, handler=backtrack_handler, **handler_args):
# print("Done")

# return new_copy


def forward_wrapper(method, cls, name):
random_name = "______dontuse_______"
def inner_wrapper(self,*args, **kwargs):

from dspy.teleprompt import Teleprompter
from dspy.signatures import Signature, SignatureMeta
import random
r = random.randbytes(10)
stack = inspect.stack()
def hash_stack(stack : inspect.FrameInfo):
if random_name in stack.frame.f_locals:
return stack.frame.f_locals[random_name]
else:
r = random.randint(0,2147483647)
stack.frame.f_locals[random_name] = r
return r

current_id = hash_stack(stack[0])
parent_frame_id = None
parent_object_id = None
parent_name = None
def return_caller_or_none(frame_info : inspect.FrameInfo):

caller_locals = frame_info.frame.f_locals
if 'self' not in caller_locals:
return
caller = caller_locals['self']
if isinstance(caller, Module) or isinstance(caller, Teleprompter):
return caller

try:
for i in range(len(stack)):
current_frame_info = stack[i]
caller = return_caller_or_none(current_frame_info)
if caller and caller != self and (current_frame_info.function == "inner_wrapper" or
isinstance(caller, Teleprompter)):
parent_frame_id = hash_stack(current_frame_info)
parent_object_id = id(caller)
parent_name = type(caller).__name__
break
# potential_next_stack = None if i == len(stack) - 1 else stack[i+1]
# potential_next_caller = return_caller_or_none(potential_next_stack)
# if potential_next_caller == self:
# current_id = id(potential_next_stack)
# if caller and caller != self and potential_next_caller != caller:
# parent_frame_id = id(current_frame_info.frame)
# parent_object_id = id(caller)
# parent_name = type(caller).__name__
# break
except Exception as e :
print(e)
pass
import time
result = method(self, *args, **kwargs)
newkargs = {k: v.model_json_schema() if isinstance(v,SignatureMeta) else v for k, v in kwargs.items()}

trace_obj = {
'class_name' : name,
'object_id' : id(self),
'frame_id' : current_id,
'parent_frame_id': parent_frame_id,
'parent_object_id' : parent_object_id,
'parent_name': parent_name,
'args' : args,
'kwargs' : newkargs,
'file': os.path.abspath(inspect.getfile(self.__class__)),
'line' : inspect.findsource(self.__class__)[1],
'time': int(time.time()),
'result' : result.toDict() if hasattr(result, "toDict") else None
}

dsp.settings.debug_trace.append({k: v for k, v in trace_obj.items() if v is not None})
return result
return inner_wrapper


def set_attribute_by_name(obj, name, value):
Expand Down