diff --git a/dill/_dill.py b/dill/_dill.py index 33fabb74..09e57cfa 100644 --- a/dill/_dill.py +++ b/dill/_dill.py @@ -185,6 +185,8 @@ def get_file_type(*args, **kwargs): import inspect import typing +from pickle import GLOBAL, EMPTY_DICT, MARK, DICT, SETITEM + ### Shims for different versions of Python and dill class Sentinel(object): @@ -357,6 +359,7 @@ def __init__(self, file, *args, **kwds): self._recurse = settings['recurse'] if _recurse is None else _recurse self._postproc = OrderedDict() self._file = file + self._globals_cache = {} def save(self, obj, save_persistent_id=True): # numpy hack @@ -1182,12 +1185,13 @@ def _repr_dict(obj): @register(dict) def save_module_dict(pickler, obj): - if is_dill(pickler, child=False) and obj == pickler._main.__dict__ and \ + _is_dill = is_dill(pickler, child=False) + if _is_dill and obj == pickler._main.__dict__ and \ not (pickler._session and pickler._first_pass): logger.trace(pickler, "D1: %s", _repr_dict(obj)) # obj pickler.write(bytes('c__builtin__\n__main__\n', 'UTF-8')) logger.trace(pickler, "# D1") - elif (not is_dill(pickler, child=False)) and (obj == _main_module.__dict__): + elif (not _is_dill) and (obj == _main_module.__dict__): logger.trace(pickler, "D3: %s", _repr_dict(obj)) # obj pickler.write(bytes('c__main__\n__dict__\n', 'UTF-8')) #XXX: works in general? logger.trace(pickler, "# D3") @@ -1197,12 +1201,37 @@ def save_module_dict(pickler, obj): logger.trace(pickler, "D4: %s", _repr_dict(obj)) # obj pickler.write(bytes('c%s\n__dict__\n' % obj['__name__'], 'UTF-8')) logger.trace(pickler, "# D4") + elif _is_dill and id(obj) in pickler._globals_cache: + logger.trace(pickler, "D5: %s", _repr_dict(obj)) # obj + # This is a globals dictionary that was partially copied, but not fully saved. + # Save the dictionary again to ensure that everything is there. + globs_copy = pickler._globals_cache[id(obj)] + pickler.write(pickler.get(pickler.memo[id(globs_copy)][0])) + pickler._batch_setitems(iter(obj.items())) + del pickler._globals_cache[id(obj)] + pickler.memo[id(obj)] = (pickler.memo.pop(id(globs_copy))[0], obj) + logger.trace(pickler, "# D5") else: logger.trace(pickler, "D2: %s", _repr_dict(obj)) # obj - if is_dill(pickler, child=False) and pickler._session: + if _is_dill and pickler._session: # we only care about session the first pass thru pickler._first_pass = False - StockPickler.save_dict(pickler, obj) + + # IMPORTANT: update the following code whenever save_dict is changed in pickle.py + # StockPickler.save_dict(pickler, obj) + if pickler.bin: + pickler.write(EMPTY_DICT) + else: # proto 0 -- can't use EMPTY_DICT + pickler.write(MARK + DICT) + + pickler.memoize(obj) + # add __name__ first + if '__name__' in obj: + pickler.save('__name__') + pickler.save(obj['__name__']) + pickler.write(SETITEM) + pickler._batch_setitems(obj.items()) + logger.trace(pickler, "# D2") return @@ -1803,45 +1832,63 @@ def save_function(pickler, obj): logger.trace(pickler, "F1: %s", obj) _recurse = getattr(pickler, '_recurse', None) _postproc = getattr(pickler, '_postproc', None) - _main_modified = getattr(pickler, '_main_modified', None) - _original_main = getattr(pickler, '_original_main', __builtin__)#'None' + _original_main = getattr(pickler, '_original_main', None) + _globals_cache = getattr(pickler, '_globals_cache', None) postproc_list = [] - if _recurse: - # recurse to get all globals referred to by obj - from .detect import globalvars - globs_copy = globalvars(obj, recurse=True, builtin=True) - # Add the name of the module to the globs dictionary to prevent - # the duplication of the dictionary. Pickle the unpopulated - # globals dictionary and set the remaining items after the function - # is created to correctly handle recursion. - globs = {'__name__': obj.__module__} - else: - globs_copy = obj.__globals__ + is_memoized = id(obj.__globals__) in pickler.memo + is_modified_main_dict = ( + _original_main is not None + and obj.__globals__ is _original_main.__dict__ + ) + is_module_dict = ( + not (_recurse or is_memoized or is_modified_main_dict) + and obj.__module__ is not None + and obj.__globals__ is getattr(_import_module(obj.__module__, safe=True), '__dict__', None) + ) + if is_modified_main_dict: # If the globals is the __dict__ from the module being saved as a # session, substitute it by the dictionary being actually saved. - if _main_modified and globs_copy is _original_main.__dict__: - globs_copy = getattr(pickler, '_main', _original_main).__dict__ - globs = globs_copy + globs = pickler._main.__dict__ + elif is_memoized or is_module_dict: # If the globals is a module __dict__, do not save it in the pickle. - elif globs_copy is not None and obj.__module__ is not None and \ - getattr(_import_module(obj.__module__, True), '__dict__', None) is globs_copy: - globs = globs_copy + # It is possible that the globals dictionary itself is also being + # pickled directly. + globs = obj.__globals__ + else: + if _recurse: + # recurse to get all globals referred to by obj + from .detect import globalvars + globs_copy = globalvars(obj, recurse=True, builtin=True) else: - globs = {'__name__': obj.__module__} + # function not bound to an importable module + globs_copy = obj.__globals__ - if globs_copy is not None and globs is not globs_copy: - # In the case that the globals are copied, we need to ensure that - # the globals dictionary is updated when all objects in the - # dictionary are already created. - glob_ids = {id(g) for g in globs_copy.values()} - for stack_element in _postproc: - if stack_element in glob_ids: - _postproc[stack_element].append((_setitems, (globs, globs_copy))) - break + # Add the name of the module to the globs dictionary and prevent + # the duplication of the dictionary. Pickle the unpopulated + # globals dictionary and set the remaining items after the function + # is created to correctly handle recursion. + if _globals_cache is not None and obj.__globals__ is not None: + if id(obj.__globals__) not in _globals_cache: + globs = {'__name__': obj.__module__} + _globals_cache[id(obj.__globals__)] = globs + else: + globs = _globals_cache[id(obj.__globals__)] else: - postproc_list.append((_setitems, (globs, globs_copy))) + globs = {'__name__': obj.__module__} + + if globs_copy is not None and globs is not globs_copy: + # In the case that the globals are copied, we need to ensure that + # the globals dictionary is updated when all objects in the + # dictionary are already created. + glob_ids = {id(g) for g in globs_copy.values()} + for stack_element in _postproc: + if stack_element in glob_ids: + _postproc[stack_element].append((_setitems, (globs, globs_copy))) + break + else: + postproc_list.append((_setitems, (globs, globs_copy))) closure = obj.__closure__ state_dict = {} diff --git a/dill/session.py b/dill/session.py index e9843a71..99befe1c 100644 --- a/dill/session.py +++ b/dill/session.py @@ -244,17 +244,18 @@ def dump_module( if filename is None: filename = str(TEMPDIR/'session.pkl') file = open(filename, 'wb') + original_main = main + if refimported: + main = _stash_modules(main) try: pickler = Pickler(file, protocol, **kwds) - pickler._original_main = main - if refimported: - main = _stash_modules(main) pickler._main = main #FIXME: dill.settings are disabled pickler._byref = False # disable pickling by name reference pickler._recurse = False # disable pickling recursion for globals pickler._session = True # is best indicator of when pickling a session pickler._first_pass = True - pickler._main_modified = main is not pickler._original_main + if main is not original_main: + pickler._original_main = original_main pickler.dump(main) finally: if file is not filename: # if newly opened file diff --git a/dill/tests/test_functions.py b/dill/tests/test_functions.py index cceb64cb..6b0cdfd8 100644 --- a/dill/tests/test_functions.py +++ b/dill/tests/test_functions.py @@ -132,7 +132,33 @@ def test_code_object(): except Exception as error: raise Exception("failed to construct code object with format version {}".format(version)) from error + +def test_shared_globals(): + import dill, test_functors as f, sys + + for recurse in False, True: + g, h = dill.copy((f.g, f.h), recurse=recurse) + assert f.g.__globals__ is f.h.__globals__ + assert g.__globals__ is h.__globals__ + assert f.g.__globals__ is g.__globals__ + assert g(1, 2) == h() == 3 + + del sys.modules['test_functors'] + + g, h = dill.copy((f.g, f.h), recurse=recurse) + assert f.g.__globals__ is f.h.__globals__ + assert g.__globals__ is h.__globals__ + assert f.g.__globals__ is not g.__globals__ + assert g(1, 2) == h() == 3 + g1, g, g2 = dill.copy((f.__dict__, f.g, f.g.__globals__), recurse=recurse) + assert g1 is g.__globals__ + assert g1 is g2 + + sys.modules['test_functors'] = f + + if __name__ == '__main__': test_functions() test_issue_510() test_code_object() + test_shared_globals() diff --git a/dill/tests/test_functors.py b/dill/tests/test_functors.py index ffe29dff..ede9b598 100644 --- a/dill/tests/test_functors.py +++ b/dill/tests/test_functors.py @@ -10,17 +10,18 @@ import dill dill.settings['recurse'] = True +x = 3 def f(a, b, c): # without keywords pass def g(a, b, c=2): # with keywords - pass + return h(a=a, b=b, c=c) def h(a=1, b=2, c=3): # without args - pass + return x def test_functools():