diff --git a/quimb/experimental/belief_propagation/__init__.py b/quimb/experimental/belief_propagation/__init__.py index 6582fe46..64fd25e4 100644 --- a/quimb/experimental/belief_propagation/__init__.py +++ b/quimb/experimental/belief_propagation/__init__.py @@ -59,7 +59,7 @@ - [ ] (HV2BP) hyper, vectorized, 2-norm - [ ] (HL1BP) hyper, lazy, 1-norm - [ ] (HL2BP) hyper, lazy, 2-norm -- [ ] (D1BP) simple, dense, 1-norm +- [x] (D1BP) simple, dense, 1-norm - simple BP for simple tensor networks - [x] (D2BP) simple, dense, 2-norm - this is the standard PEPS BP algorithm - [ ] (V1BP) simple, vectorized, 1-norm - [ ] (V2BP) simple, vectorized, 2-norm @@ -78,14 +78,30 @@ """ from .bp_common import initialize_hyper_messages -from .d2bp import D2BP, contract_d2bp, compress_d2bp, sample_d2bp +from .d1bp import D1BP, contract_d1bp +from .d2bp import D2BP, compress_d2bp, contract_d2bp, sample_d2bp +from .hd1bp import HD1BP, contract_hd1bp, sample_hd1bp +from .hv1bp import HV1BP, contract_hv1bp, sample_hv1bp +from .l1bp import L1BP, contract_l1bp +from .l2bp import L2BP, compress_l2bp, contract_l2bp __all__ = ( - "initialize_hyper_messages", - "D2BP", - "contract_d2bp", "compress_d2bp", - "sample_d2bp", + "compress_l2bp", + "contract_d1bp", + "contract_d2bp", + "contract_hd1bp", + "contract_hv1bp", + "contract_l1bp", + "contract_l2bp", + "D1BP", + "D2BP", "HD1BP", "HV1BP", + "initialize_hyper_messages", + "L1BP", + "L2BP", + "sample_d2bp", + "sample_hd1bp", + "sample_hv1bp", ) diff --git a/quimb/experimental/belief_propagation/bp_common.py b/quimb/experimental/belief_propagation/bp_common.py index bf492208..d2d5cb96 100644 --- a/quimb/experimental/belief_propagation/bp_common.py +++ b/quimb/experimental/belief_propagation/bp_common.py @@ -99,7 +99,11 @@ def run(self, max_iterations=1000, tol=5e-6, info=None, progbar=False): info["rolling_abs_mean_diff"] = rdm.absmeandiff() -def initialize_hyper_messages(tn, fill_fn=None, smudge_factor=1e-12): +def initialize_hyper_messages( + tn, + fill_fn=None, + smudge_factor=1e-12, +): """Initialize messages for belief propagation, this is equivalent to doing a single round of belief propagation with uniform messages. diff --git a/quimb/experimental/belief_propagation/d1bp.py b/quimb/experimental/belief_propagation/d1bp.py new file mode 100644 index 00000000..616cf87a --- /dev/null +++ b/quimb/experimental/belief_propagation/d1bp.py @@ -0,0 +1,316 @@ +"""Belief propagation for standard tensor networks. This: + +- assumes no hyper indices, only standard bonds. +- assumes a single ('dense') tensor per site +- works directly on the '1-norm' i.e. scalar tensor network + +This is the simplest version of belief propagation, and is useful for +simple investigations. 
+""" + +import autoray as ar + +from quimb.tensor.contraction import array_contract +from quimb.utils import oset + +from .bp_common import ( + BeliefPropagationCommon, + combine_local_contractions, +) +from .hd1bp import ( + compute_all_tensor_messages_tree, +) + + +def initialize_messages(tn, fill_fn=None): + + backend = ar.infer_backend(next(t.data for t in tn)) + _sum = ar.get_lib_fn(backend, "sum") + + messages = {} + for ix, tids in tn.ind_map.items(): + if len(tids) != 2: + continue + tida, tidb = tids + + for tid_from, tid_to in [(tida, tidb), (tidb, tida)]: + t_from = tn.tensor_map[tid_from] + if fill_fn is not None: + d = t_from.ind_size(ix) + m = fill_fn((d,)) + else: + m = array_contract( + arrays=(t_from.data,), + inputs=(tuple(range(t_from.ndim)),), + output=(t_from.inds.index(ix),), + ) + messages[ix, tid_to] = m / _sum(m) + + return messages + + +class D1BP(BeliefPropagationCommon): + """Dense (as in one tensor per site) 1-norm (as in for 'classical' systems) + belief propagation algorithm. Allows message reuse. This version assumes no + hyper indices (i.e. a standard tensor network). This is the simplest + version of belief propagation. + + Parameters + ---------- + tn : TensorNetwork + The tensor network to run BP on. + messages : dict[(str, int), array_like], optional + The initial messages to use, effectively defaults to all ones if not + specified. + damping : float, optional + The damping factor to use, 0.0 means no damping. + update : {'sequential', 'parallel'}, optional + Whether to update messages sequentially or in parallel. + local_convergence : bool, optional + Whether to allow messages to locally converge - i.e. if all their + input messages have converged then stop updating them. + fill_fn : callable, optional + If specified, use this function to fill in the initial messages. + + Attributes + ---------- + tn : TensorNetwork + The target tensor network. + messages : dict[(str, int), array_like] + The current messages. The key is a tuple of the index and tensor id + that the message is being sent to. + key_pairs : dict[(str, int), (str, int)] + A dictionary mapping the key of a message to the key of the message + propagating in the opposite direction. 
+    """
+
+    def __init__(
+        self,
+        tn,
+        messages=None,
+        damping=0.0,
+        update="sequential",
+        local_convergence=True,
+        message_init_function=None,
+    ):
+        self.tn = tn
+        self.damping = damping
+        self.local_convergence = local_convergence
+        self.update = update
+
+        self.backend = next(t.backend for t in tn)
+        _abs = ar.get_lib_fn(self.backend, "abs")
+        _sum = ar.get_lib_fn(self.backend, "sum")
+
+        def _normalize(x):
+            return x / _sum(x)
+
+        def _distance(x, y):
+            return _sum(_abs(x - y))
+
+        self._normalize = _normalize
+        self._distance = _distance
+
+        if messages is None:
+            self.messages = initialize_messages(self.tn, message_init_function)
+        else:
+            self.messages = messages
+
+        # record which messages touch which tids, for efficient updates
+        self.touched = oset()
+        self.key_pairs = {}
+        for ix, tids in tn.ind_map.items():
+            if len(tids) != 2:
+                continue
+            tida, tidb = tids
+            self.key_pairs[ix, tidb] = (ix, tida)
+            self.key_pairs[ix, tida] = (ix, tidb)
+
+    def iterate(self, tol=5e-6):
+        if (not self.local_convergence) or (not self.touched):
+            # assume if asked to iterate that we want to check all messages
+            self.touched = oset(self.tn.tensor_map)
+
+        ncheck = len(self.touched)
+        nconv = 0
+        max_mdiff = -1.0
+        new_touched = oset()
+
+        def _compute_ms(tid):
+            t = self.tn.tensor_map[tid]
+            new_ms = compute_all_tensor_messages_tree(
+                t.data,
+                [self.messages[ix, tid] for ix in t.inds],
+                self.backend,
+            )
+            new_ms = [self._normalize(m) for m in new_ms]
+            new_ks = [self.key_pairs[ix, tid] for ix in t.inds]
+
+            return new_ks, new_ms
+
+        def _update_m(key, data):
+            nonlocal nconv, max_mdiff
+
+            m = self.messages[key]
+            if self.damping != 0.0:
+                data = (1 - self.damping) * data + self.damping * m
+
+            mdiff = float(self._distance(m, data))
+            if mdiff > tol:
+                # mark destination tid for update
+                new_touched.add(key[1])
+            else:
+                nconv += 1
+
+            max_mdiff = max(max_mdiff, mdiff)
+            self.messages[key] = data
+
+        if self.update == "sequential":
+            # compute each new message and immediately re-insert it
+            while self.touched:
+                tid = self.touched.pop()
+                keys, new_ms = _compute_ms(tid)
+                for key, data in zip(keys, new_ms):
+                    _update_m(key, data)
+
+        elif self.update == "parallel":
+            new_data = {}
+            # compute all new messages
+            while self.touched:
+                tid = self.touched.pop()
+                keys, new_ms = _compute_ms(tid)
+                for key, data in zip(keys, new_ms):
+                    new_data[key] = data
+            # insert all new messages
+            for key, data in new_data.items():
+                _update_m(key, data)
+
+        self.touched = new_touched
+        return nconv, ncheck, max_mdiff
+
+    def normalize_messages(self):
+        """Normalize all messages such that for each bond `<m_i|m_j> = 1` and
+        `<m_i|m_i> = <m_j|m_j>` (but in general != 1).
+        """
+        for ix, tids in self.tn.ind_map.items():
+            if len(tids) != 2:
+                continue
+            tida, tidb = tids
+            mi = self.messages[ix, tida]
+            mj = self.messages[ix, tidb]
+            nij = abs(mi @ mj)**0.5
+            nii = (mi @ mi)**0.25
+            njj = (mj @ mj)**0.25
+            self.messages[ix, tida] = mi / (nij * nii / njj)
+            self.messages[ix, tidb] = mj / (nij * njj / nii)
+
+    def get_gauged_tn(self):
+        """Gauge the original TN by inserting the BP-approximated transfer
+        matrix eigenvectors, which may be complex. The BP-contraction of this
+        gauged network is then simply the product of zeroth entries of each
+        tensor.
+ """ + tng = self.tn.copy() + for ind, tids in self.tn.ind_map.items(): + tida, tidb = tids + ka = (ind, tida) + kb = (ind, tidb) + ma = self.messages[ka] + mb = self.messages[kb] + + el, ev = ar.do('linalg.eig', ar.do('outer', ma, mb)) + k = ar.do('argsort', -ar.do('abs', el)) + ev = ev[:, k] + Uinv = ev + U = ar.do('linalg.inv', ev) + tng._insert_gauge_tids(U, tida, tidb, Uinv) + return tng + + def contract(self, strip_exponent=False): + tvals = [] + for tid, t in self.tn.tensor_map.items(): + arrays = [t.data] + inputs = [tuple(range(t.ndim))] + for i, ix in enumerate(t.inds): + m = self.messages[ix, tid] + arrays.append(m) + inputs.append((i,)) + tvals.append( + array_contract( + arrays=arrays, + inputs=inputs, + output=(), + ) + ) + + mvals = [] + for ix, tids in self.tn.ind_map.items(): + if len(tids) != 2: + continue + tida, tidb = tids + mvals.append( + self.messages[ix, tida] @ self.messages[ix, tidb] + ) + + return combine_local_contractions( + tvals, mvals, self.backend, strip_exponent=strip_exponent + ) + + + +def contract_d1bp( + tn, + max_iterations=1000, + tol=5e-6, + damping=0.0, + update="sequential", + local_convergence=True, + strip_exponent=False, + info=None, + progbar=False, + **contract_opts, +): + """Estimate the contraction of standard tensor network ``tn`` using dense + 1-norm belief propagation. + + Parameters + ---------- + tn : TensorNetwork + The tensor network to contract, it should have no dangling or hyper + indices. + max_iterations : int, optional + The maximum number of iterations to run for. + tol : float, optional + The convergence tolerance for messages. + damping : float, optional + The damping parameter to use, defaults to no damping. + update : {'sequential', 'parallel'}, optional + Whether to update messages sequentially or in parallel. + local_convergence : bool, optional + Whether to allow messages to locally converge - i.e. if all their + input messages have converged then stop updating them. + strip_exponent : bool, optional + Whether to strip the exponent from the final result. If ``True`` + then the returned result is ``(mantissa, exponent)``. + info : dict, optional + If specified, update this dictionary with information about the + belief propagation run. + progbar : bool, optional + Whether to show a progress bar. + """ + bp = D1BP( + tn, + damping=damping, + local_convergence=local_convergence, + update=update, + **contract_opts, + ) + bp.run( + max_iterations=max_iterations, + tol=tol, + info=info, + progbar=progbar, + ) + return bp.contract( + strip_exponent=strip_exponent, + ) diff --git a/quimb/experimental/belief_propagation/d2bp.py b/quimb/experimental/belief_propagation/d2bp.py index bca48e97..3ff31c50 100644 --- a/quimb/experimental/belief_propagation/d2bp.py +++ b/quimb/experimental/belief_propagation/d2bp.py @@ -1,5 +1,7 @@ import autoray as ar + import quimb.tensor as qtn +from quimb.utils import oset from .bp_common import ( BeliefPropagationCommon, @@ -32,6 +34,10 @@ class D2BP(BeliefPropagationCommon): Computed automatically if not specified. optimize : str or PathOptimizer, optional The path optimizer to use when contracting the messages. + damping : float, optional + The damping factor to use, 0.0 means no damping. + update : {'parallel', 'sequential'}, optional + Whether to update all messages in parallel or sequentially. local_convergence : bool, optional Whether to allow messages to locally converge - i.e. if all their input messages have converged then stop updating them. 
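The `damping` and `update` options documented above follow one convention throughout this PR: a damped update mixes the freshly computed message with the previous one before reinsertion, which can stabilize fixed-point iterations that would otherwise oscillate. A minimal standalone sketch of the idea, in plain numpy with a hypothetical helper (`damped_fixed_point` is illustrative only, not part of this PR):

```python
import numpy as np

def damped_fixed_point(f, x0, damping=0.5, tol=5e-6, max_iterations=1000):
    # iterate x <- (1 - damping) * f(x) + damping * x until converged
    x = x0
    for _ in range(max_iterations):
        x_new = (1 - damping) * f(x) + damping * x
        if np.max(np.abs(x_new - x)) < tol:
            return x_new
        x = x_new
    raise RuntimeError("fixed point iteration did not converge")

# undamped, x -> 1 - x oscillates between 0 and 1 forever;
# with damping it converges to the fixed point 0.5
print(damped_fixed_point(lambda x: 1 - x, np.array(0.0)))
```

This mirrors the `data = (1 - self.damping) * data + self.damping * m` lines in `D1BP._update_m` above and the equivalent blocks in `D2BP`, `L1BP` and `L2BP` below.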
@@ -45,8 +51,9 @@ def __init__( messages=None, output_inds=None, optimize="auto-hq", - local_convergence=True, damping=0.0, + update="sequential", + local_convergence=True, **contract_opts, ): from quimb.tensor.contraction import array_contract_expression @@ -54,8 +61,9 @@ def __init__( self.tn = tn self.contract_opts = contract_opts self.contract_opts.setdefault("optimize", optimize) - self.local_convergence = local_convergence self.damping = damping + self.local_convergence = local_convergence + self.update = update if output_inds is None: self.output_inds = set(self.tn.outer_inds()) @@ -82,7 +90,7 @@ def _distance(x, y): # record which messages touch each others, for efficient updates self.touch_map = {} - self.touched = set() + self.touched = oset() self.exprs = {} # populate any messages @@ -168,40 +176,55 @@ def iterate(self, tol=5e-6): self.touched.update(self.exprs.keys()) ncheck = len(self.touched) - new_messages = {} - while self.touched: - key = self.touched.pop() + nconv = 0 + max_mdiff = -1.0 + new_touched = oset() + + def _compute_m(key): expr, data = self.exprs[key] m = expr(*data[:2], *(self.messages[mkey] for mkey in data[2:])) # enforce hermiticity and normalize - m = m + ar.dag(m) - m = self._normalize(m) + return self._normalize(m + ar.dag(m)) + def _update_m(key, new_m): + nonlocal nconv, max_mdiff + + old_m = self.messages[key] if self.damping > 0.0: - m = self._normalize( - # new message - (1 - self.damping) * m - + - # old message - self.damping * self.messages[key] + new_m = self._normalize( + self.damping * old_m + (1 - self.damping) * new_m ) - - new_messages[key] = m - - # process modified messages - nconv = 0 - max_mdiff = -1.0 - for key, m in new_messages.items(): - mdiff = float(self._distance(m, self.messages[key])) - + try: + mdiff = float(self._distance(old_m, new_m)) + except (TypeError, ValueError): + # handle e.g. lazy arrays + mdiff = float("inf") if mdiff > tol: # mark touching messages for update - self.touched.update(self.touch_map[key]) + new_touched.update(self.touch_map[key]) else: nconv += 1 - max_mdiff = max(max_mdiff, mdiff) - self.messages[key] = m + self.messages[key] = new_m + + if self.update == "parallel": + new_messages = {} + # compute all new messages + while self.touched: + key = self.touched.pop() + new_messages[key] = _compute_m(key) + # insert all new messages + for key, new_m in new_messages.items(): + _update_m(key, new_m) + + elif self.update == "sequential": + # compute each new message and immediately re-insert it + while self.touched: + key = self.touched.pop() + new_m = _compute_m(key) + _update_m(key, new_m) + + self.touched = new_touched return nconv, ncheck, max_mdiff @@ -354,8 +377,9 @@ def contract_d2bp( messages=None, output_inds=None, optimize="auto-hq", - local_convergence=True, damping=0.0, + update="sequential", + local_convergence=True, max_iterations=1000, tol=5e-6, strip_exponent=False, @@ -382,11 +406,13 @@ def contract_d2bp( Computed automatically if not specified. optimize : str or PathOptimizer, optional The path optimizer to use when contracting the messages. + damping : float, optional + The damping parameter to use, defaults to no damping. + update : {'parallel', 'sequential'}, optional + Whether to update all messages in parallel or sequentially. local_convergence : bool, optional Whether to allow messages to locally converge - i.e. if all their input messages have converged then stop updating them. - damping : float, optional - The damping parameter to use, defaults to no damping. 
strip_exponent : bool, optional Whether to strip the exponent from the final result. If ``True`` then the returned result is ``(mantissa, exponent)``. @@ -407,8 +433,9 @@ def contract_d2bp( messages=messages, output_inds=output_inds, optimize=optimize, - local_convergence=local_convergence, damping=damping, + local_convergence=local_convergence, + update=update, **contract_opts, ) bp.run( @@ -429,8 +456,9 @@ def compress_d2bp( messages=None, output_inds=None, optimize="auto-hq", - local_convergence=True, damping=0.0, + update="sequential", + local_convergence=True, max_iterations=1000, tol=5e-6, inplace=False, @@ -465,6 +493,8 @@ def compress_d2bp( The path optimizer to use when contracting the messages. damping : float, optional The damping parameter to use, defaults to no damping. + update : {'parallel', 'sequential'}, optional + Whether to update all messages in parallel or sequentially. local_convergence : bool, optional Whether to allow messages to locally converge - i.e. if all their input messages have converged then stop updating them. @@ -487,8 +517,9 @@ def compress_d2bp( messages=messages, output_inds=output_inds, optimize=optimize, - local_convergence=local_convergence, damping=damping, + update=update, + local_convergence=local_convergence, **contract_opts, ) bp.run( diff --git a/quimb/experimental/belief_propagation/hd1bp.py b/quimb/experimental/belief_propagation/hd1bp.py index 8f72f8ae..9233bedb 100644 --- a/quimb/experimental/belief_propagation/hd1bp.py +++ b/quimb/experimental/belief_propagation/hd1bp.py @@ -5,6 +5,7 @@ TODO: - [ ] implement 'touching', so that only necessary messages are updated +- [ ] implement sequential update """ import autoray as ar @@ -202,7 +203,7 @@ def iterate_belief_propagation_basic( backend = ar.infer_backend(next(iter(messages.values()))) # _sum = ar.get_lib_fn(backend, "sum") - # nb at small sizes python sum is faster than numpy sum + # n.b. at small sizes python sum is faster than numpy sum _sum = ar.get_lib_fn(backend, "sum") # _max = ar.get_lib_fn(backend, "max") _abs = ar.get_lib_fn(backend, "abs") @@ -290,6 +291,28 @@ def iterate(self, **kwargs): ) return None, None, max_dm + def get_gauged_tn(self): + """Assuming the supplied tensor network has no hyper or dangling + indices, gauge it by inserting the BP-approximated transfer matrix + eigenvectors, which may be complex. The BP-contraction of this gauged + network is then simply the product of zeroth entries of each tensor. + """ + tng = self.tn.copy() + for ind, tids in self.tn.ind_map.items(): + tida, tidb = tids + ka = (ind, tida) + kb = (ind, tidb) + ma = self.messages[ka] + mb = self.messages[kb] + + el, ev = ar.do('linalg.eig', ar.do('outer', ma, mb)) + k = ar.do('argsort', -ar.do('abs', el)) + ev = ev[:, k] + Uinv = ev + U = ar.do('linalg.inv', ev) + tng._insert_gauge_tids(U, tida, tidb, Uinv) + return tng + def contract(self, strip_exponent=False): """Estimate the total contraction, i.e. the exponential of the 'Bethe free entropy'. 
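The `get_gauged_tn` method added to `HD1BP` (and `D1BP` above) diagonalizes, for each bond, the rank-1 'transfer matrix' formed from the two opposing messages, and inserts that eigenbasis as a gauge. A self-contained numpy sketch of the single-bond operation (the message values here are hypothetical):

```python
import numpy as np

# two opposing BP messages on a single bond of dimension 4
rng = np.random.default_rng(42)
ma, mb = rng.random(4), rng.random(4)

# eigendecompose the BP-approximated transfer matrix for this bond,
# sorting so the dominant eigenvector comes first
el, ev = np.linalg.eig(np.outer(ma, mb))
ev = ev[:, np.argsort(-np.abs(el))]
Uinv, U = ev, np.linalg.inv(ev)

# inserting Uinv @ U = 1 on the bond leaves the network unchanged, but in
# the new gauge the rank-1 transfer matrix is diagonal with a single
# nonzero entry at [0, 0] - hence 'product of zeroth entries of each tensor'
print(np.round(U @ np.outer(ma, mb) @ Uinv, 8))
```

Since `linalg.eig` can return complex eigenvectors in general, the gauged tensors may be complex, as the docstring warns.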
diff --git a/quimb/experimental/belief_propagation/l1bp.py b/quimb/experimental/belief_propagation/l1bp.py
index 3c4d56c6..4822c938 100644
--- a/quimb/experimental/belief_propagation/l1bp.py
+++ b/quimb/experimental/belief_propagation/l1bp.py
@@ -1,11 +1,12 @@
 import autoray as ar
 
 import quimb.tensor as qtn
+from quimb.utils import oset
 
 from .bp_common import (
     BeliefPropagationCommon,
-    create_lazy_community_edge_map,
     combine_local_contractions,
+    create_lazy_community_edge_map,
 )
 
 
@@ -23,6 +24,8 @@ class L1BP(BeliefPropagationCommon):
         these are inferred automatically.
     damping : float, optional
         The damping parameter to use, defaults to no damping.
+    update : {'parallel', 'sequential'}, optional
+        Whether to update all messages in parallel or sequentially.
     local_convergence : bool, optional
         Whether to allow messages to locally converge - i.e. if all their
         input messages have converged then stop updating them.
@@ -37,8 +40,8 @@ def __init__(
         tn,
         site_tags=None,
         damping=0.0,
+        update="sequential",
         local_convergence=True,
-        update="parallel",
         optimize="auto-hq",
         message_init_function=None,
         **contract_opts,
@@ -61,7 +64,7 @@ def __init__(
             self.local_tns,
             self.touch_map,
         ) = create_lazy_community_edge_map(tn, site_tags)
-        self.touched = set()
+        self.touched = oset()
 
         self._abs = ar.get_lib_fn(self.backend, "abs")
         self._max = ar.get_lib_fn(self.backend, "max")
@@ -72,6 +75,12 @@ def __init__(
         self._norm = ar.get_lib_fn(self.backend, "linalg.norm")
 
         def _normalize(x):
+
+            # sx = self._sum(x)
+            # sphase = sx / self._abs(sx)
+            # smag = self._norm(x)**0.5
+            # return x / (smag * sphase)
+
             return x / self._sum(x)
             # return x / self._norm(x)
             # return x / self._max(x)
@@ -135,7 +144,7 @@ def iterate(self, tol=5e-6):
         ncheck = len(self.touched)
         nconv = 0
         max_mdiff = -1.0
-        new_touched = set()
+        new_touched = oset()
 
         def _compute_m(key):
             i, j = key
@@ -154,7 +163,10 @@ def _update_m(key, data):
 
             tm = self.messages[key]
 
-            if self.damping != 0.0:
+            if callable(self.damping):
+                damping_m = self.damping()
+                data = (1 - damping_m) * data + damping_m * tm.data
+            elif self.damping != 0.0:
                 data = (1 - self.damping) * data + self.damping * tm.data
 
             mdiff = float(self._distance(tm.data, data))
@@ -222,6 +234,19 @@ def contract(self, strip_exponent=False):
             tvals, mvals, self.backend, strip_exponent=strip_exponent
         )
 
+    def normalize_messages(self):
+        """Normalize all messages such that for each bond `<m_i|m_j> = 1` and
+        `<m_i|m_i> = <m_j|m_j>` (but in general != 1).
+        """
+        for i, j in self.edges:
+            tmi = self.messages[i, j]
+            tmj = self.messages[j, i]
+            nij = abs(tmi @ tmj)**0.5
+            nii = (tmi @ tmi)**0.25
+            njj = (tmj @ tmj)**0.25
+            tmi /= (nij * nii / njj)
+            tmj /= (nij * njj / nii)
+
 
 def contract_l1bp(
     tn,
@@ -229,8 +254,8 @@
     tol=5e-6,
     site_tags=None,
     damping=0.0,
+    update="sequential",
     local_convergence=True,
-    update="parallel",
     optimize="auto-hq",
     strip_exponent=False,
     info=None,
@@ -253,6 +278,8 @@
         automatically.
     damping : float, optional
         The damping parameter to use, defaults to no damping.
+    update : {'parallel', 'sequential'}, optional
+        Whether to update all messages in parallel or sequentially.
    local_convergence : bool, optional
        Whether to allow messages to locally converge - i.e. if all their
        input messages have converged then stop updating them.
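The `normalize_messages` method, added in the same form to `D1BP`, `L1BP` and (below) `L2BP`, fixes the scale freedom in each opposing message pair. A quick numpy check of the convention it enforces, with hypothetical message values:

```python
import numpy as np

# a pair of opposing messages on one bond
mi, mj = np.array([0.7, 0.3]), np.array([0.2, 0.8])

nij = abs(mi @ mj) ** 0.5
nii = (mi @ mi) ** 0.25
njj = (mj @ mj) ** 0.25
mi = mi / (nij * nii / njj)
mj = mj / (nij * njj / nii)

print(mi @ mj)           # -> 1.0: <m_i|m_j> = 1
print(mi @ mi, mj @ mj)  # equal to each other, but in general != 1
```

The shared factor `nij` fixes the cross overlap, while the asymmetric factors `nii / njj` and `njj / nii` balance the two self-overlaps against each other.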
diff --git a/quimb/experimental/belief_propagation/l2bp.py b/quimb/experimental/belief_propagation/l2bp.py
index 3eb04e18..15ff5ea0 100644
--- a/quimb/experimental/belief_propagation/l2bp.py
+++ b/quimb/experimental/belief_propagation/l2bp.py
@@ -3,23 +3,47 @@
 import autoray as ar
 
 import quimb.tensor as qtn
+from quimb.utils import oset
+
 from .bp_common import (
     BeliefPropagationCommon,
-    create_lazy_community_edge_map,
     combine_local_contractions,
+    create_lazy_community_edge_map,
 )
 
 
 class L2BP(BeliefPropagationCommon):
-    """A simple class to hold all the data for a L2BP run."""
+    """Lazy (as in multiple uncontracted tensors per site) 2-norm (as in for
+    wavefunctions and operators) belief propagation.
+
+    Parameters
+    ----------
+    tn : TensorNetwork
+        The tensor network to form the 2-norm of and run BP on.
+    site_tags : sequence of str, optional
+        The tags identifying the sites in ``tn``; each tag forms a region,
+        and the regions should not overlap. If the tensor network is
+        structured, then these are inferred automatically.
+    damping : float, optional
+        The damping parameter to use, defaults to no damping.
+    update : {'parallel', 'sequential'}, optional
+        Whether to update all messages in parallel or sequentially.
+    local_convergence : bool, optional
+        Whether to allow messages to locally converge - i.e. if all their
+        input messages have converged then stop updating them.
+    optimize : str or PathOptimizer, optional
+        The path optimizer to use when contracting the messages.
+    contract_opts
+        Other options supplied to ``cotengra.array_contract``.
+    """
 
     def __init__(
         self,
         tn,
         site_tags=None,
         damping=0.0,
+        update="sequential",
         local_convergence=True,
-        update="parallel",
         optimize="auto-hq",
         **contract_opts,
     ):
@@ -41,7 +65,7 @@ def __init__(
             self.local_tns,
             self.touch_map,
         ) = create_lazy_community_edge_map(tn, site_tags)
-        self.touched = set()
+        self.touched = oset()
 
         _abs = ar.get_lib_fn(self.backend, "abs")
         _sum = ar.get_lib_fn(self.backend, "sum")
@@ -67,7 +91,7 @@ def _distance(x, y):
 
         self.messages = {}
         for pair, bix in self.edges.items():
-            cix = tuple(ix + "*" for ix in bix)
+            cix = tuple(ix + "_l2bp*" for ix in bix)
             remapper = dict(zip(bix, cix))
             output_inds = cix + bix
 
@@ -107,10 +131,10 @@ def _distance(x, y):
             for ix in tn_i_right.ind_map:
                 if ix in bix:
                     # bra outputs
-                    remapper[ix] = ix + "**"
+                    remapper[ix] = ix + "_l2bp**"
                 elif ix in outer_bix:
                     # messages connected
-                    remapper[ix] = ix + "*"
+                    remapper[ix] = ix + "_l2bp*"
                 # remaining indices are either internal and will be mangled
                 # or global outer indices and will be contracted directly
 
@@ -130,12 +154,12 @@ def iterate(self, tol=5e-6):
         ncheck = len(self.touched)
         nconv = 0
         max_mdiff = -1.0
-        new_touched = set()
+        new_touched = oset()
 
         def _compute_m(key):
             i, j = key
             bix = self.edges[(i, j) if i < j else (j, i)]
-            cix = tuple(ix + "**" for ix in bix)
+            cix = tuple(ix + "_l2bp**" for ix in bix)
             output_inds = cix + bix
 
             tn_i_to_j = self.contraction_tns[i, j]
@@ -159,7 +183,11 @@ def _update_m(key, data):
             if self.damping > 0.0:
                 data = (1 - self.damping) * data + self.damping * tm.data
 
-            mdiff = float(self._distance(tm.data, data))
+            try:
+                mdiff = float(self._distance(tm.data, data))
+            except (TypeError, ValueError):
+                # handle e.g. lazy arrays
+                mdiff = float("inf")
 
             if mdiff > tol:
                 # mark touching messages for update
@@ -191,6 +219,19 @@ def _update_m(key, data):
 
         return nconv, ncheck, max_mdiff
 
+    def normalize_messages(self):
+        """Normalize all messages such that for each bond `<m_i|m_j> = 1` and
+        `<m_i|m_i> = <m_j|m_j>` (but in general != 1).
+ """ + for i, j in self.edges: + tmi = self.messages[i, j] + tmj = self.messages[j, i] + nij = (tmi @ tmj)**0.5 + nii = (tmi @ tmi)**0.25 + njj = (tmj @ tmj)**0.25 + tmi /= (nij * nii / njj) + tmj /= (nij * njj / nii) + def contract(self, strip_exponent=False): """Estimate the contraction of the norm squared using the current messages. @@ -201,7 +242,7 @@ def contract(self, strip_exponent=False): # disconnected but still appear in local_tns ks = self.neighbors.get(i, ()) bix = [ix for k in ks for ix in self.edges[tuple(sorted((k, i)))]] - bra = ket.H.reindex_({ix: ix + "*" for ix in bix}) + bra = ket.H.reindex_({ix: ix + "_l2bp*" for ix in bix}) tni = qtn.TensorNetwork( ( ket, @@ -253,11 +294,11 @@ def partial_trace( for ix in tn_bra_i.ind_map: if ix == ket_site_ind: # open up the site index - bra_site_ind = ix + "**" + bra_site_ind = ix + "_l2bp**" ind_changes[ix] = bra_site_ind if ix in outer_bix: # attach bra message indices - ind_changes[ix] = ix + "*" + ind_changes[ix] = ix + "_l2bp*" tn_bra_i.reindex_(ind_changes) tn_rho_i &= tn_bra_i @@ -345,6 +386,7 @@ def contract_l2bp( tn, site_tags=None, damping=0.0, + update="sequential", local_convergence=True, optimize="auto-hq", max_iterations=1000, @@ -364,6 +406,8 @@ def contract_l2bp( The tags identifying the sites in ``tn``, each tag forms a region. damping : float, optional The damping parameter to use, defaults to no damping. + update : {'parallel', 'sequential'}, optional + Whether to update all messages in parallel or sequentially. local_convergence : bool, optional Whether to allow messages to locally converge - i.e. if all their input messages have converged then stop updating them. @@ -388,6 +432,7 @@ def contract_l2bp( tn, site_tags=site_tags, damping=damping, + update=update, local_convergence=local_convergence, optimize=optimize, **contract_opts, @@ -410,6 +455,7 @@ def compress_l2bp( tol=5e-6, site_tags=None, damping=0.0, + update="sequential", local_convergence=True, optimize="auto-hq", lazy=False, @@ -441,6 +487,8 @@ def compress_l2bp( automatically. damping : float, optional The damping parameter to use, defaults to no damping. + update : {'parallel', 'sequential'}, optional + Whether to update all messages in parallel or sequentially. local_convergence : bool, optional Whether to allow messages to locally converge - i.e. if all their input messages have converged then stop updating them. @@ -469,6 +517,7 @@ def compress_l2bp( tnc, site_tags=site_tags, damping=damping, + update=update, local_convergence=local_convergence, optimize=optimize, **contract_opts, diff --git a/quimb/experimental/tensor_1d_mpo_gate_methods.py b/quimb/experimental/tensor_1d_mpo_gate_methods.py deleted file mode 100644 index 5c6a2a7f..00000000 --- a/quimb/experimental/tensor_1d_mpo_gate_methods.py +++ /dev/null @@ -1,346 +0,0 @@ -"""Methods for acting with an MPO on an MPS. - - -TODO: - -- [x] density matrix method -- [x] zip-up method -- [ ] implement early compress boundary method -- [ ] find out why projector method is slower than expected -- [ ] left/right compressed optimal projector method - -""" -from quimb.tensor.tensor_core import ( - ensure_dict, - rand_uuid, - tensor_contract, - TensorNetwork, -) - - -def mps_gate_with_mpo_boundary( - self, - mpo, - max_bond, - cutoff=0.0, -): - return mpo.apply(self, compress=True, max_bond=max_bond, cutoff=cutoff) - - -def mps_gate_with_mpo_lazy(self, mpo): - """Apply an MPO to an MPS lazily, i.e. nothing is contracted, but the new - TN object has the same outer indices as the original MPS. 
- """ - mps_calc = self.copy() - mpo_calc = mpo.copy() - - outerid = self.site_ind_id - innerid = rand_uuid() + "{}" - - mps_calc.site_ind_id = innerid - mpo_calc.lower_ind_id = innerid - mpo_calc.upper_ind_id = outerid - - mps_calc |= mpo_calc - - mps_calc._site_ind_id = outerid - - return mps_calc - - -def mps_gate_with_mpo_fit( - self, - mpo, - max_bond, - cutoff=0.0, - init_guess=None, - **fit_opts, -): - """Fit a MPS to a MPO applied to an MPS using either ALS or autodiff. - - Some nice alternatives to the default fit_opts: - - - method="autodiff" - - method="als", solver="lstsq" - - """ - if cutoff != 0.0: - raise ValueError("cutoff must be zero for fitting") - - target = mps_gate_with_mpo_lazy(self, mpo) - - if init_guess is None: - ansatz = self.copy() - ansatz.expand_bond_dimension_(max_bond) - else: - raise NotImplementedError - - return ansatz.fit_(target, **fit_opts) - - -def mps_gate_with_mpo_projector( - self, - mpo, - max_bond, - cutoff=0.0, - canonize=False, - canonize_opts=None, - inplace=False, - **compress_opts, -): - tn = mps_gate_with_mpo_lazy(self, mpo) - - if canonize: - # precondition - canonize_opts = ensure_dict(canonize_opts) - tn.gauge_all_(**canonize_opts) - - tn_calc = tn.copy() - - for i in range(tn.L - 1): - ltags = (tn.site_tag(i),) - rtags = (tn.site_tag(i + 1),) - - tn_calc.insert_compressor_between_regions_( - ltags, - rtags, - new_ltags=ltags, - new_rtags=rtags, - max_bond=max_bond, - cutoff=cutoff, - insert_into=tn, - bond_ind=self.bond(i, i + 1), - **compress_opts, - ) - - if inplace: - for i in range(tn.L): - ti = self[i] - data = tensor_contract( - *tn[i], output_inds=ti.inds, optimize="auto-hq" - ).data - ti.modify(data=data) - - else: - for i in range(tn.L): - tn.contract_tags_( - tn.site_tag(i), - output_inds=self[i].inds, - optimize="auto-hq", - ) - - tn.view_like_(self) - - return tn - - -def tensor_1d_compress_dm( - self, - max_bond=None, - cutoff=1e-10, - optimize="auto-hq", - normalize=False, - **compress_opts, -): - ket = self.copy() - bra = ket.H - # doing this means forming the norm doesn't do its own mangling - bra.mangle_inner_() - # form the overlapping double layer TN - norm = bra & ket - # open the bra's indices back up - bra.reindex_all_("__b{}") - - # construct dense left environments - left_envs = {} - left_envs[1] = norm.select(0).contract(optimize=optimize, drop_tags=True) - for i in range(2, self.L): - left_envs[i] = tensor_contract( - left_envs[i - 1], - *norm.select(i - 1), - optimize=optimize, - drop_tags=True, - ) - - # build projectors and right environments - Us = [] - right_env_ket = None - right_env_bra = None - for i in range(self.L - 1, 0, -1): - # form the reduced density matrix - rho_tensors = [left_envs[i], *ket.select(i), *bra.select(i)] - left_inds = [ket.site_ind(i)] - right_inds = [bra.site_ind(i)] - if right_env_ket is not None: - rho_tensors.extend([right_env_ket, right_env_bra]) - left_inds.append(f"__kr{i + 1}") - right_inds.append(f"__br{i + 1}") - - # contract and then split it - rhoi = tensor_contract(*rho_tensors, optimize=optimize) - U, s, UH = rhoi.split( - left_inds=left_inds, - right_inds=right_inds, - method="eigh", - max_bond=max_bond, - cutoff=cutoff, - get="tensors", - absorb=None, - **compress_opts, - ) - - # turn bond into 'virtual right' indices - (bix,) = s.inds - U.reindex_({bix: f"__kr{i}"}) - UH.reindex_({bix: f"__br{i}"}) - Us.append(U) - - # attach the unitaries to the right environments and contract - right_ket_tensors = [*ket.select(i), U.H] - right_bra_tensors = [*bra.select(i), 
UH.H] - if right_env_ket is not None: - right_ket_tensors.append(right_env_ket) - right_bra_tensors.append(right_env_bra) - - right_env_ket = tensor_contract( - *right_ket_tensors, optimize=optimize, drop_tags=True - ) - # TODO: could compute this just as conjugated and relabelled ket env - right_env_bra = tensor_contract( - *right_bra_tensors, optimize=optimize, drop_tags=True - ) - - # form the final site - U0 = tensor_contract(*ket.select(0), right_env_ket, optimize=optimize) - - if normalize: - # in right canonical form already - U0.normalize_() - - new = TensorNetwork([U0] + Us[::-1]) - # cast as whatever the input was e.g. MPS - new.view_like_(self) - # this puts the array indices in canonical order - new.permute_arrays() - - return new - - -def mps_gate_with_mpo_dm( - mps, - mpo, - max_bond=None, - cutoff=1e-10, - **compress_opts, -): - """Gate this MPS with an MPO, using the density matrix compression method. - - Parameters - ---------- - mps : MatrixProductState - The MPS to gate. - mpo : MatrixProductOperator - The MPO to gate with. - max_bond : int, optional - The maximum bond dimension to keep when compressing the double layer - tensor network, if any. - cutoff : float, optional - The truncation error to use when compressing the double layer tensor - network, if any. - compress_opts - Supplied to :func:`~quimb.tensor.tensor_split`. - """ - # form the double layer tensor network - target = mps_gate_with_mpo_lazy(mps, mpo) - - # directly compress it without first contracting site-wise - return tensor_1d_compress_dm(target, max_bond, cutoff, **compress_opts) - - -def mps_gate_with_mpo_zipup( - mps, - mpo, - max_bond=None, - cutoff=1e-10, - canonize=True, - optimize="auto-hq", - **compress_opts, -): - """ - "Minimally Entangled Typical Thermal State Algorithms", E.M. Stoudenmire & - Steven R. White (https://arxiv.org/abs/1002.1305). - """ - mps = mps.copy() - mpo = mpo.copy() - - if canonize: - # put in 'pseudo' right canonical form: - # - # │ │ │ │ │ │ │ │ │ │ - # ○─◀─◀─◀─◀─◀─◀─◀─◀─◀ MPO - # │ │ │ │ │ │ │ │ │ │ - # ○─◀─◀─◀─◀─◀─◀─◀─◀─◀ MPS - # - mps.right_canonize() - mpo.right_canonize() - - # form double layer - tn = mps_gate_with_mpo_lazy(mps, mpo) - - # zip along the bonds - Us = [] - bix = None - sVH = None - for i in range(tn.L - 1): - # sVH - # │ │ │ │ │ │ │ │ │ │ │ - # ▶═▶═▶═▶══□──◀─◀─◀─◀─◀─◀─◀ - # : ╲ │ │ │ │ │ │ │ - # max_bond ◀─◀─◀─◀─◀─◀─◀ - # i - # .... contract - if sVH is None: - # first site - C = tn.select(i).contract(optimize=optimize) - else: - C = (sVH | tn.select(i)).contract(optimize=optimize) - # i - # │ │ │ │ │ │ │ │ │ │ │ - # ▶═▶═▶═▶════□──◀─◀─◀─◀─◀─◀ - # : ╲ │ │ │ │ │ │ - # bix : ◀─◀─◀─◀─◀─◀ - # split - left_inds = [mps.site_ind(i)] - if bix is not None: - left_inds.append(bix) - - # the new bond index - bix = rand_uuid() - - U, sVH = C.split( - left_inds, - max_bond=max_bond, - cutoff=cutoff, - absorb='right', - bond_ind=bix, - get='tensors', - **compress_opts, - ) - sVH.drop_tags() - Us.append(U) - # i - # │ │ │ │ │ │ │ │ │ │ │ - # ▶═▶═▶═▶══▶═□──◀─◀─◀─◀─◀─◀ - # ╲ │ │ │ │ │ │ - # : : ◀─◀─◀─◀─◀─◀ - # U sVH - - Us.append((sVH | tn.select(tn.L - 1)).contract(optimize=optimize)) - - new = TensorNetwork(Us) - # cast as whatever the input was e.g. 
MPS - new.view_like_(mps) - # this puts the array indices in canonical order - new.permute_arrays() - - return new diff --git a/quimb/tensor/tensor_arbgeom_compress.py b/quimb/tensor/tensor_arbgeom_compress.py index 189204dc..0d3580f9 100644 --- a/quimb/tensor/tensor_arbgeom_compress.py +++ b/quimb/tensor/tensor_arbgeom_compress.py @@ -366,7 +366,7 @@ def tensor_network_ag_compress_l2bp( canonize=True, damping=0.0, local_convergence=True, - update="parallel", + update="sequential", optimize="auto-hq", inplace=False, **compress_opts, diff --git a/quimb/utils.py b/quimb/utils.py index 27659f6e..87ed368d 100644 --- a/quimb/utils.py +++ b/quimb/utils.py @@ -1,11 +1,10 @@ -"""Misc utility functions. -""" +"""Misc utility functions.""" + +import collections import functools import itertools -import collections from importlib.util import find_spec - try: import cytoolz @@ -373,7 +372,13 @@ def clear(self): def update(self, *others): for o in others: - self._d.update(o._d) + try: + # oset + self._d.update(o._d) + except AttributeError: + # iterable + for k in o: + self._d[k] = None def union(self, *others): u = self.copy() @@ -419,6 +424,8 @@ def popleft(self): def popright(self): return self._d.popitem()[0] + pop = popright + def __eq__(self, other): if isinstance(other, oset): return self._d == other._d diff --git a/tests/test_tensor/test_belief_propagation/test_d1bp.py b/tests/test_tensor/test_belief_propagation/test_d1bp.py new file mode 100644 index 00000000..30347ca9 --- /dev/null +++ b/tests/test_tensor/test_belief_propagation/test_d1bp.py @@ -0,0 +1,39 @@ +import pytest + +import quimb as qu +import quimb.tensor as qtn +from quimb.experimental.belief_propagation import ( + D1BP, + contract_d1bp, +) + + +def test_contract_tree_exact(): + tn = qtn.TN_rand_tree(20, 3) + Z = tn.contract() + info = {} + Z_bp = contract_d1bp(tn, info=info, progbar=True) + assert info["converged"] + assert Z == pytest.approx(Z_bp, rel=1e-12) + + +@pytest.mark.parametrize("damping", [0.0, 0.1]) +def test_contract_normal(damping): + tn = qtn.TN2D_from_fill_fn(lambda s: qu.randn(s, dist="uniform"), 6, 6, 2) + Z = tn.contract() + info = {} + Z_bp = contract_d1bp(tn, damping=damping, info=info, progbar=True) + assert info["converged"] + assert Z == pytest.approx(Z_bp, rel=1e-1) + + +def test_get_gauged_tn(): + tn = qtn.TN2D_from_fill_fn(lambda s: qu.randn(s, dist="uniform"), 6, 6, 2) + Z = tn.contract() + bp = D1BP(tn) + bp.run() + Zbp = bp.contract() + assert Z == pytest.approx(Zbp, rel=1e-1) + tn_gauged = bp.get_gauged_tn() + Zg = qu.prod(array.item(0) for array in tn_gauged.arrays) + assert Z == pytest.approx(Zg, rel=1e-1) diff --git a/tests/test_tensor/test_belief_propagation/test_hd1bp.py b/tests/test_tensor/test_belief_propagation/test_hd1bp.py index 57b62327..1d8de8d8 100644 --- a/tests/test_tensor/test_belief_propagation/test_hd1bp.py +++ b/tests/test_tensor/test_belief_propagation/test_hd1bp.py @@ -3,6 +3,7 @@ import quimb as qu import quimb.tensor as qtn from quimb.experimental.belief_propagation.hd1bp import ( + HD1BP, contract_hd1bp, sample_hd1bp, ) @@ -49,3 +50,15 @@ def test_sample(damping): assert tn_config.num_indices == 0 assert tn_config.contract() == pytest.approx(1.0) assert 0.0 < omega < 1.0 + + +def test_get_gauged_tn(): + tn = qtn.TN2D_from_fill_fn(lambda s: qu.randn(s, dist="uniform"), 6, 6, 2) + Z = tn.contract() + bp = HD1BP(tn) + bp.run() + Zbp = bp.contract() + assert Z == pytest.approx(Zbp, rel=1e-1) + tn_gauged = bp.get_gauged_tn() + Zg = qu.prod(array.item(0) for array in 
tn_gauged.arrays) + assert Z == pytest.approx(Zg, rel=1e-1)
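
For reference, the new top-level entry points mirror the existing `contract_d2bp` interface. A minimal usage sketch of the `D1BP` additions, following the structure of the tests above:

```python
import quimb as qu
import quimb.tensor as qtn
from quimb.experimental.belief_propagation import D1BP, contract_d1bp

# a small 2D scalar tensor network with positive entries
tn = qtn.TN2D_from_fill_fn(lambda s: qu.randn(s, dist="uniform"), 6, 6, 2)

# one-shot estimate of the contraction
Z_bp = contract_d1bp(tn, max_iterations=1000, tol=5e-6)

# or keep the BP object around to reuse its converged messages
bp = D1BP(tn)
bp.run()
Z_bp = bp.contract()
tn_gauged = bp.get_gauged_tn()  # BP-gauged copy of the same network
```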