Skip to content

Commit

Permalink
Merge pull request #1519 from chenmoneygithub/fix-reset-copy
Browse files Browse the repository at this point in the history
Fix dspy Module's deepcopy
  • Loading branch information
okhat committed Sep 22, 2024
2 parents c1729b2 + fdc942b commit f5773c4
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 11 deletions.
46 changes: 35 additions & 11 deletions dspy/primitives/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,15 +102,46 @@ def parameters(self):
return [param for _, param in self.named_parameters()]

def deepcopy(self):
    """Return a copy of this module that is as deep as its attributes allow.

    A plain ``copy.deepcopy(self)`` is attempted first. If that raises (for
    example because an attribute holds an object that cannot be deep-copied),
    a new instance is created without running ``__init__`` and every entry of
    ``self.__dict__`` is copied individually:

    * ``BaseModule`` values are copied through their own ``deepcopy()``;
    * any other value is deep-copied, falling back to ``copy.copy`` and
      finally to sharing the original reference.

    Returns:
        A new instance of the same class. Attributes that could not be
        deep-copied may be shallow copies of, or the very same objects as,
        the attributes of ``self``.
    """
    # Local import so this method does not depend on a module-level import.
    import logging

    logger = logging.getLogger(__name__)

    try:
        # Fast path: the whole instance is deep-copyable.
        return copy.deepcopy(self)
    except Exception:
        # Deliberate best-effort: fall through to the per-attribute copy.
        logger.debug(
            "copy.deepcopy failed for %s; copying attributes one by one.",
            self.__class__.__name__,
            exc_info=True,
        )

    # Create an empty instance (``__init__`` is intentionally not run).
    new_instance = self.__class__.__new__(self.__class__)
    # Set attributes of the copied instance.
    for attr, value in self.__dict__.items():
        if isinstance(value, BaseModule):
            # Recurse so nested modules get the same fallback treatment.
            setattr(new_instance, attr, value.deepcopy())
            continue

        try:
            copied = copy.deepcopy(value)
        except Exception:
            logger.warning(
                "Failed to deep copy attribute '%s' of %s; falling back to a shallow copy.",
                attr,
                self.__class__.__name__,
            )
            try:
                copied = copy.copy(value)
            except Exception:
                # Not even shallow-copyable: the copy shares the object with the original.
                logger.warning(
                    "Failed to shallow copy attribute '%s' of %s; sharing the original reference.",
                    attr,
                    self.__class__.__name__,
                )
                copied = value
        setattr(new_instance, attr, copied)

    return new_instance

def reset_copy(self):
    """Return a copy of this module with every parameter reset.

    The copy is produced by ``self.deepcopy()``; ``reset()`` is then called
    on each item returned by the copy's ``parameters()``. ``self`` is not
    modified by this method itself.

    Returns:
        The copied instance, after its parameters have been reset.
    """
    new_instance = self.deepcopy()

    for param in new_instance.parameters():
        param.reset()

    return new_instance

def dump_state(self, save_verbose):
print(self.named_parameters())
Expand All @@ -119,13 +150,6 @@ def dump_state(self, save_verbose):
def load_state(self, state):
    """Restore every named parameter from ``state``.

    Args:
        state: Object indexed by parameter name; ``state[name]`` is handed
            to the ``load_state`` method of the parameter with that name.
            With a plain dict, a missing name raises ``KeyError``.
    """
    for param_name, parameter in self.named_parameters():
        saved = state[param_name]
        parameter.load_state(saved)
# try:
# param.load_state(state[name])
# except KeyError:
# if name.endswith("._predict"):
# param.load_state(state[name[:-9]])
# else:
# raise

def save(self, path, save_field_meta=False):
with open(path, "w") as f:
Expand Down
48 changes: 48 additions & 0 deletions tests/primitives/test_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import dspy
import threading


def test_deepcopy_basic():
    """A deep-copied ChainOfThought has distinct parameters with equal state."""
    original = dspy.ChainOfThought(dspy.Signature("q -> a"))
    duplicate = original.deepcopy()

    assert len(duplicate.parameters()) == len(original.parameters())

    # The first parameter must be a different object carrying the same values.
    first_original = original.parameters()[0]
    first_duplicate = duplicate.parameters()[0]
    assert first_original is not first_duplicate
    assert first_original.__dict__ == first_duplicate.__dict__


def test_deepcopy_with_uncopyable_modules():
    """A module holding a lock can still be copied; the lock itself is shared."""

    class CustomClass(dspy.Module):
        def __init__(self):
            self.lock = threading.Lock()  # Cannot be deep-copied.
            self.cot = dspy.ChainOfThought(dspy.Signature("q -> a"))

    source = CustomClass()
    clone = source.deepcopy()

    assert len(clone.parameters()) == len(source.parameters())

    # The lock should refer to the same object in both instances.
    assert source.lock is clone.lock

    # The first parameter must be a different object carrying the same values.
    source_param = source.parameters()[0]
    clone_param = clone.parameters()[0]
    assert source_param is not clone_param
    assert source_param.__dict__ == clone_param.__dict__


def test_deepcopy_with_nested_modules():
    """Copying works when the lock-holding module is nested in another module."""

    class CustomClass1(dspy.Module):
        def __init__(self):
            self.lock = threading.Lock()  # Cannot be deep-copied.
            self.cot = dspy.ChainOfThought(dspy.Signature("q -> a"))

    class CustomClass2(dspy.Module):
        def __init__(self):
            self.submodel = CustomClass1()

    source = CustomClass2()
    clone = source.deepcopy()

    assert len(clone.parameters()) == len(source.parameters())

    # The nested lock should refer to the same object in both instances.
    assert source.submodel.lock is clone.submodel.lock

    # The first parameter must be a different object carrying the same values.
    source_param = source.parameters()[0]
    clone_param = clone.parameters()[0]
    assert source_param is not clone_param
    assert source_param.__dict__ == clone_param.__dict__

0 comments on commit f5773c4

Please sign in to comment.