Source code for oars.algorithms.parallel

import numpy as np
import multiprocessing as mp
from oars.algorithms.helpers import ConvergenceChecker, getWarmPrimal
from time import time

def parallelAlgorithm(n, data, resolvents, W, Z, warmstartprimal=None, warmstartdual=None, itrs=1001, gamma=0.9, alpha=1.0, vartol=None, checkperiod=1, verbose=False):
    """Run the frugal resolvent splitting algorithm for W and Z matrices in parallel

    One pool worker is started per resolvent and each runs ``subproblem``.
    When ``vartol`` is given, an additional ``evaluate`` process is started
    which receives the per-iteration updates from every node and sets a
    shared termination value.

    Args:
        n (int): the number of resolvents
        data (list): list containing the problem data for each resolvent
        resolvents (list): list of :math:`n` resolvent functions
        W (ndarray): size (n, n) ndarray for the :math:`W` matrix
        Z (ndarray): size (n, n) ndarray for the :math:`Z` matrix
        warmstartprimal (ndarray, optional): resolvent.shape ndarray for :math:`x` in v^0
        warmstartdual (list, optional): is a list of n ndarrays for :math:`u` which sums to 0 in v^0
        itrs (int, optional): the number of iterations
        gamma (float, optional): parameter in :math:`v^{k+1} = v^k - \\gamma W x^k`
        alpha (float, optional): the resolvent step size in :math:`x^{k+1} = J_{\\alpha F^i}(y^k)`
        vartol (float, optional): is the variable tolerance; when None no
            evaluation process is started and all ``itrs`` iterations run
        checkperiod (int, optional): forwarded to ``evaluate`` (only used when
            ``vartol`` is given)
        verbose (bool, optional): True for verbose output

    Returns:
        xbar (ndarray): the solution, computed as the mean of the ``x`` entries
            of all node results
        results (list): list of dictionaries with the results for each node;
            when ``verbose`` is True, ``results[0]['alg_time']`` holds the
            measured loop time
    """
    L = -np.tril(Z, -1)

    # The manager runs its own server process. Using it as a context manager
    # shuts that process down on exit (including on error) instead of leaking
    # it. Everything that touches a manager proxy must stay inside this block.
    with mp.Manager() as man:
        # Create the queues
        Queue_Array, Comms_Data = requiredQueues(man, W, L)
        evalProcess = None
        if vartol is not None:
            terminate = man.Value('i', 0)
            Queue_Array['terminate'] = [man.Queue() for _ in range(n)]
            # Create evaluation process
            evalProcess = mp.Process(
                target=evaluate,
                args=(n, Queue_Array['terminate'], terminate, vartol, itrs, checkperiod, verbose),
            )
            evalProcess.start()
        else:
            terminate = None

        # Set v0
        if warmstartprimal is not None:
            all_v = getWarmPrimal(warmstartprimal, L)
            if verbose:
                print('warmstartprimal', all_v)
        else:
            all_v = [0 for _ in range(n)]
        if warmstartdual is not None:
            all_v = [all_v[i] + warmstartdual[i] for i in range(n)]
            if verbose:
                print('warmstart final', all_v)

        # Run subproblems in parallel
        if verbose:
            print('Starting Parallel Algorithm')
            t = time()
        # processes=n: the subproblems exchange data through the queues while
        # they run, so the pool is sized to hold all n of them at once.
        with mp.Pool(processes=n) as p:
            params = [
                (i, data[i], resolvents[i], all_v[i], W, L, Comms_Data[i],
                 Queue_Array, gamma, alpha, itrs, terminate, verbose)
                for i in range(n)
            ]
            results = p.starmap(subproblem, params)
        if verbose:
            alg_time = time() - t
            print('Parallel Algorithm Loop Time:', alg_time)

        # Join the evaluation process before the manager (and its queues) go away
        if evalProcess is not None:
            evalProcess.join()

    xbar = np.mean([results[i]['x'] for i in range(n)], axis=0)
    if verbose:
        results[0]['alg_time'] = alg_time
    return xbar, results
def requiredQueues(man, W, L):
    '''
    Returns the queues for the given W and L matrices

    Only the strictly lower-triangular entries (j < i) of W and L are
    inspected; an entry counts as non-zero when ``np.isclose(entry, 0.0)``
    is False.

    Args:
        man (multiprocessing manager): the manager used to create the queues
        W (ndarray): is the n x n W matrix
        L (ndarray): is the n x n L matrix

    Returns:
        Queue_Array (dict): is the dictionary of the queues with keys (i,j) for the queues from i to j
        Comms_Data (list): is a list of the required comms data for each node
        The comms data entry for node i is a dictionary with the following keys:

        - WQ: nodes which feed only W data into node i
        - up_LQ: nodes which feed only L data into node i
        - down_LQ: nodes which receive only L data from node i
        - up_BQ: nodes which feed both W and L data into node i, and node i feeds W back to
        - down_BQ: nodes which receive W and L data from node i
    '''
    # Get the number of nodes
    n = W.shape[0]
    Queue_Array = {}
    # One independent set of (initially empty) neighbour lists per node
    Comms_Data = [
        {'WQ': [], 'up_LQ': [], 'down_LQ': [], 'up_BQ': [], 'down_BQ': []}
        for _ in range(n)
    ]

    # Each pair (i, j) with j < i is visited exactly once, so the queue keys
    # written below can never already exist in Queue_Array.
    for i in range(n):
        comms_i = Comms_Data[i]
        for j in range(i):
            comms_j = Comms_Data[j]
            has_L = not np.isclose(L[i, j], 0.0)
            if not np.isclose(W[i, j], 0.0):
                # W coupling needs a queue in both directions
                Queue_Array[i, j] = man.Queue()
                Queue_Array[j, i] = man.Queue()
                if has_L:
                    comms_i['up_BQ'].append(j)
                    comms_j['down_BQ'].append(i)
                else:
                    comms_j['WQ'].append(i)
                    comms_i['WQ'].append(j)
            elif has_L:
                # L-only coupling is one-directional: j feeds i
                Queue_Array[j, i] = man.Queue()
                comms_i['up_LQ'].append(j)
                comms_j['down_LQ'].append(i)

    return Queue_Array, Comms_Data


def subproblem(i, data, problem_builder, v0, W, L, comms_data, queue, gamma=0.5, alpha=1.0, itrs=501, terminate=None, verbose=False):
    '''
    Solves the parallel subproblem for node i

    Args:
        i (int): is the node number
        data (dict): is a dictionary containing arguments for the problem
        problem_builder (class): is a prox class for the problem; it is called
            as ``problem_builder(data)`` and the result must provide ``shape``
            and ``prox(y, alpha)``
        v0: the initial value of the consensus variable for this node
        W (ndarray): is the n x n W matrix
        L (ndarray): is the n x n L matrix
        comms_data (dict): is a dictionary with the following keys

            - WQ: nodes which feed only W data into node i
            - up_LQ: nodes which feed only L data into node i
            - down_LQ: nodes which receive only L data from node i
            - up_BQ: nodes which feed both W and L data into node i, and node i feeds W back to
            - down_BQ: nodes which receive W and L data from node i
        queue (dict): is the array of queues, keyed (sender, receiver); when
            ``terminate`` is given it also holds the list ``queue['terminate']``
        gamma (float): is the consensus parameter
        alpha (float): is the step size passed to ``prox``
        itrs (int): is the number of iterations
        terminate (multiprocessing value): is the termination value
        verbose (bool): is a boolean for verbose output

    Returns:
        results (dict): is a dictionary with the following keys

            - x: the solution
            - v: the consensus variable
            - log: the log of the problem (only if the resolvent has a ``log`` attribute)
    '''
    # Create the problem
    resolvent = problem_builder(data)
    m = resolvent.shape
    v_temp = np.zeros(m)
    local_v = v0
    local_r = np.zeros(m)
    w_value = np.zeros(m)

    # Iterate over the problem
    itr = 0
    # itrs // 10 is 0 for itrs < 10, which would make the progress check below
    # divide by zero; clamp the reporting period to at least one iteration.
    itr_period = max(itrs // 10, 1)
    while itr < itrs:
        if terminate is not None:
            # Read the shared value once per iteration: evaluate() writes it at
            # most once, so repeated reads would only add round-trips.
            stop_at = terminate.value
            if stop_at != 0:
                # NOTE(review): a node that first sees the stop value exactly
                # at itr == stop_at still runs this iteration - confirm intended.
                if stop_at < itr:
                    break
                itrs = stop_at
        if verbose and i == 0 and itr % itr_period == 0:
            print(f'Iteration {itr}')

        # Get data from upstream L queue
        for k in comms_data['up_LQ']:
            local_r += L[i, k] * queue[k, i].get()

        # Pull from the B queues, update r and v_temp
        for k in comms_data['up_BQ']:
            temp = queue[k, i].get()
            local_r += L[i, k] * temp
            v_temp += W[i, k] * temp

        # Solve the problem
        w_value = resolvent.prox(local_v + local_r, alpha)

        # Put data in downstream queues
        for k in comms_data['down_LQ']:
            queue[i, k].put(w_value)
        for k in comms_data['down_BQ']:
            queue[i, k].put(w_value)

        # Put data in upstream W queues
        for k in comms_data['WQ']:
            queue[i, k].put(w_value)
        for k in comms_data['up_BQ']:
            queue[i, k].put(w_value)

        # Update v from all W queues
        for k in comms_data['WQ']:
            v_temp += W[i, k] * queue[k, i].get()

        # Update v from all B queues
        for k in comms_data['down_BQ']:
            v_temp += W[i, k] * queue[k, i].get()

        v_update = gamma * (W[i, i] * w_value + v_temp)
        local_v = local_v - v_update

        # Report this iteration's update to the evaluation process
        if terminate is not None:
            queue['terminate'][i].put(v_update)

        # Zero out v_temp without reallocating memory
        v_temp.fill(0)
        local_r.fill(0)
        itr += 1

    if hasattr(resolvent, 'log'):
        return {'x': w_value, 'v': local_v, 'log': resolvent.log}
    return {'x': w_value, 'v': local_v}


def evaluate(n, terminateQueue, terminate, vartol, itrs, checkperiod=1, verbose=False):
    """
    Evaluate the termination conditions and set the terminate value if needed

    The terminate value is set a number of iterations ahead of the convergence iteration

    One value is first read from every queue without being checked. Then, per
    iteration, one value is read from every queue and the sum of their norms
    is compared with ``vartol``. After ``n`` consecutive iterations below the
    tolerance, ``terminate.value`` is set to ``itr + 2*n`` and the function
    returns.

    Args:
        n (int): the number of nodes
        terminateQueue (list): the list of queues for termination
        terminate (multiprocessing value): the termination value
        vartol (float): the variable tolerance
        itrs (int): the number of iterations
        checkperiod (int): the number of iterations between checks (not implemented)
        verbose (bool): True for verbose output
    """
    # Initial values from every node; these are replaced before any check
    v = [terminateQueue[i].get() for i in range(n)]

    varcounter = 0
    itr = 0
    itrs -= n
    while itr < itrs:
        if verbose:
            print('iteration', itr + 1)
        for i in range(n):
            v[i] = terminateQueue[i].get()
        delta = sum(np.linalg.norm(v[i]) for i in range(n))
        if verbose:
            print("vartol check delta", delta)
        if delta < vartol:
            varcounter += 1
            if varcounter >= n:
                terminate.value = itr + 2 * n
                if verbose:
                    print('Converged on vartol on iteration', itr)
                break
        else:
            varcounter = 0
        itr += 1