Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Shared namespaces #534

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
64 changes: 52 additions & 12 deletions dill/_dill.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def get_file_type(*args, **kwargs):
import dataclasses
import typing

from pickle import GLOBAL
from pickle import GLOBAL, EMPTY_DICT, MARK, DICT, SETITEM


### Shims for different versions of Python and dill
Expand Down Expand Up @@ -366,6 +366,7 @@ def __init__(self, file, *args, **kwds):
self._recurse = settings['recurse'] if _recurse is None else _recurse
self._postproc = OrderedDict()
self._file = file
self._globals_cache = {}

def dump(self, obj): #NOTE: if settings change, need to update attributes
# register if the object is a numpy ufunc
Expand Down Expand Up @@ -1183,12 +1184,13 @@ def _repr_dict(obj):

@register(dict)
def save_module_dict(pickler, obj):
if is_dill(pickler, child=False) and obj == pickler._main.__dict__ and \
_is_dill = is_dill(pickler, child=False)
if _is_dill and obj == pickler._main.__dict__ and \
not (pickler._session and pickler._first_pass):
logger.trace(pickler, "D1: %s", _repr_dict(obj)) # obj
pickler.write(bytes('c__builtin__\n__main__\n', 'UTF-8'))
logger.trace(pickler, "# D1")
elif (not is_dill(pickler, child=False)) and (obj == _main_module.__dict__):
elif (not _is_dill) and (obj == _main_module.__dict__):
logger.trace(pickler, "D3: %s", _repr_dict(obj)) # obj
pickler.write(bytes('c__main__\n__dict__\n', 'UTF-8')) #XXX: works in general?
logger.trace(pickler, "# D3")
Expand All @@ -1198,12 +1200,37 @@ def save_module_dict(pickler, obj):
logger.trace(pickler, "D4: %s", _repr_dict(obj)) # obj
pickler.write(bytes('c%s\n__dict__\n' % obj['__name__'], 'UTF-8'))
logger.trace(pickler, "# D4")
elif _is_dill and id(obj) in pickler._globals_cache:
logger.trace(pickler, "D5: %s", _repr_dict(obj)) # obj
# This is a globals dictionary that was partially copied, but not fully saved.
# Save the dictionary again to ensure that everything is there.
globs_copy = pickler._globals_cache[id(obj)]
pickler.write(pickler.get(pickler.memo[id(globs_copy)][0]))
pickler._batch_setitems(iter(obj.items()))
del pickler._globals_cache[id(obj)]
pickler.memo[id(obj)] = (pickler.memo.pop(id(globs_copy))[0], obj)
logger.trace(pickler, "# D5")
else:
logger.trace(pickler, "D2: %s", _repr_dict(obj)) # obj
if is_dill(pickler, child=False) and pickler._session:
if _is_dill and pickler._session:
# we only care about session the first pass thru
pickler._first_pass = False
StockPickler.save_dict(pickler, obj)

# IMPORTANT: update the following code whenever save_dict is changed in pickle.py
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It appears that the code for save_dict in all supported version of python are is the same (at least from if pickler.bin to pickler.memoize(obj)). The idea here is to insert the code with __name__ to save_dict while not yielding control to the StockPickler, correct? It seems that save_dict hasn't changed as far back as at least 3.1, so I'm not worried about the copy being made here.

# StockPickler.save_dict(pickler, obj)
if pickler.bin:
pickler.write(EMPTY_DICT)
else: # proto 0 -- can't use EMPTY_DICT
pickler.write(MARK + DICT)

pickler.memoize(obj)
# add __name__ first
if '__name__' in obj:
pickler.save('__name__')
pickler.save(obj['__name__'])
pickler.write(SETITEM)
pickler._batch_setitems(obj.items())

logger.trace(pickler, "# D2")
return

Expand Down Expand Up @@ -1792,17 +1819,18 @@ def save_function(pickler, obj):
_postproc = getattr(pickler, '_postproc', None)
_main_modified = getattr(pickler, '_main_modified', None)
_original_main = getattr(pickler, '_original_main', __builtin__)#'None'
_globals_cache = getattr(pickler, '_globals_cache', None)
postproc_list = []
if _recurse:

globs = None
if id(obj.__globals__) in pickler.memo:
# It is possible that the globals dictionary itself is also being
# pickled directly.
globs = globs_copy = obj.__globals__
elif _recurse:
# recurse to get all globals referred to by obj
from .detect import globalvars
globs_copy = globalvars(obj, recurse=True, builtin=True)

# Add the name of the module to the globs dictionary to prevent
# the duplication of the dictionary. Pickle the unpopulated
# globals dictionary and set the remaining items after the function
# is created to correctly handle recursion.
globs = {'__name__': obj.__module__}
else:
globs_copy = obj.__globals__

Expand All @@ -1815,6 +1843,18 @@ def save_function(pickler, obj):
elif globs_copy is not None and obj.__module__ is not None and \
getattr(_import_module(obj.__module__, True), '__dict__', None) is globs_copy:
globs = globs_copy

if globs is None:
# Add the name of the module to the globs dictionary and prevent
# the duplication of the dictionary. Pickle the unpopulated
# globals dictionary and set the remaining items after the function
# is created to correctly handle recursion.
if _globals_cache is not None and obj.__globals__ is not None:
if id(obj.__globals__) not in _globals_cache:
globs = {'__name__': obj.__module__}
_globals_cache[id(obj.__globals__)] = globs
else:
globs = _globals_cache[id(obj.__globals__)]
else:
globs = {'__name__': obj.__module__}

Expand Down
10 changes: 10 additions & 0 deletions dill/tests/_globals_dummy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# This file is used by test_shared_globals in test_functions.py
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it should be possible to use or modify a simple test like test_functors so that file is used as opposed to using an additional file like this one.


x = 3

def h():
return x

def g():
return h()

26 changes: 26 additions & 0 deletions dill/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,33 @@ def test_code_object():
except Exception as error:
raise Exception("failed to construct code object with format version {}".format(version)) from error


def test_shared_globals():
import dill, _globals_dummy as f, sys

for recurse in False, True:
g, h = dill.copy((f.g, f.h), recurse=recurse)
assert f.g.__globals__ is f.h.__globals__
assert g.__globals__ is h.__globals__
assert f.g.__globals__ is g.__globals__
assert g() == h() == 3

del sys.modules['_globals_dummy']

g, h = dill.copy((f.g, f.h), recurse=recurse)
assert f.g.__globals__ is f.h.__globals__
assert g.__globals__ is h.__globals__
assert f.g.__globals__ is not g.__globals__
assert g() == h() == 3
g1, g, g2 = dill.copy((f.__dict__, f.g, f.g.__globals__), recurse=recurse)
assert g1 is g.__globals__
assert g1 is g2

sys.modules['_globals_dummy'] = f


if __name__ == '__main__':
test_functions()
test_issue_510()
test_code_object()
test_shared_globals()