#!/usr/bin/env python3
from mpi4py import MPI
import numpy as np
def initialize(comm, n, data, resolvents, W, Z, gamma=0.8, alpha=1.0):
    """
    Ship each worker rank its share of the problem and return rank 0's comms plan.

    Intended to run on rank 0. For every rank j in 1..n-1 it sends, in this
    order and with mpi4py's object (lowercase) ``send``:
    data[j] with tag 44, resolvents[j] with tag 17, and the j-th entry of
    ``requiredComms(Z, W)`` with tag 33.

    Args:
        comm (MPI communicator): communicator used for the sends
        n (int): the number of resolvents (and of subproblem ranks)
        data (list): per-node problem data; entry j is sent to rank j
        resolvents (list): per-node (uninstantiated) resolvent classes
        W (ndarray): size (n, n) ndarray for the W matrix
        Z (ndarray): size (n, n) ndarray for the Z matrix
        gamma (float, optional): the consensus parameter (not used here)
        alpha (float, optional): the resolvent scaling parameter (not used here)

    Returns:
        dict: the communication plan for rank 0
    """
    plans = requiredComms(Z, W)
    for rank in range(1, n):
        outgoing = (
            (data[rank], 44),        # problem data
            (resolvents[rank], 17),  # resolvent class
            (plans[rank], 33),       # comms plan
        )
        for payload, tag in outgoing:
            comm.send(payload, dest=rank, tag=tag)
    return plans[0]
# Distributed algorithm
# NOTE(review): this function was previously marked "Does not work yet". The
# rank-layout and call-signature defects have been fixed, but it has not been
# exercised end to end.
def distributedAlgorithm(n, data, resolvents, W, Z, warmstartprimal=None, warmstartdual=None,
                         itrs=1001, gamma=0.9, alpha=1.0, vartol=None, verbose=False):
    """
    Distributed algorithm for frugal resolvent splitting.

    Every MPI rank must call this function with the same arguments (W and Z
    are used locally by each subproblem rank). Rank layout, requiring at
    least n + 1 processes:

    * rank 0 distributes the data, runs subproblem 0 and gathers all results;
    * ranks 1 .. n-1 receive their data from rank 0 and run one subproblem each;
    * rank n runs the convergence evaluator (only when ``vartol`` is given);
    * any further ranks do nothing.

    Args:
        n (int): the number of resolvents
        data (list): list containing the problem data for each resolvent
        resolvents (list): list of resolvent classes
        W (ndarray): size (n, n) ndarray for the W matrix
        Z (ndarray): size (n, n) ndarray for the Z matrix
        warmstartprimal (ndarray, optional): not used yet
        warmstartdual (list, optional): not used yet
        itrs (int, optional): the maximum number of iterations
        gamma (float, optional): step size for the consensus step
        alpha (float, optional): the resolvent step size (forwarded to
            ``initialize``, which does not use it)
        vartol (float, optional): variable tolerance; when given, rank n
            monitors node 0's iterates and stops the run once they settle
        verbose (bool, optional): True for verbose output

    Returns:
        On rank 0: ``(xbar, results)``. ``xbar`` is the average of the nodes'
        final resolvent outputs, and ``results`` is a list of ``{'x': ...}``
        dicts, one per node in rank order. All other ranks return None.

    Raises:
        ValueError: if fewer than n + 1 MPI processes are available.
    """
    from time import perf_counter

    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    if n > comm.Get_size() - 1:
        raise ValueError("Number of nodes is greater than the number of processes")

    if rank == 0:
        comms_0 = initialize(comm, n, data, resolvents, W, Z, gamma, alpha)
        if verbose:
            print("Node 0 running subproblem", flush=True)
        start = perf_counter()
        x, _ = subproblem(0, data[0], resolvents[0], W, Z, comms_0, comm, gamma, itrs,
                          vartol=vartol, verbose=verbose)
        if verbose:
            print("Time", perf_counter() - start)
        # Each node's result gets its own buffer; reusing one buffer would make
        # every entry alias the last value received.
        results = [{'x': x}]
        for k in range(1, n):
            x_k = np.empty(x.shape, dtype=np.float64)
            comm.Recv(x_k, source=k, tag=0)
            results.append({'x': x_k})
        # np.mean builds a new array, so results[0]['x'] is left untouched.
        xbar = np.mean([r['x'] for r in results], axis=0)
        return xbar, results

    if rank < n:
        local_data = comm.recv(source=0, tag=44)
        resolvent = comm.recv(source=0, tag=17)
        comms_i = comm.recv(source=0, tag=33)
        x, _ = subproblem(rank, local_data, resolvent, W, Z, comms_i, comm, gamma, itrs,
                          vartol=vartol, verbose=verbose)
        # Rank 0 receives into a contiguous float64 buffer.
        comm.Send(np.ascontiguousarray(x, dtype=np.float64), dest=0, tag=0)
        return None

    if rank == n and vartol is not None and itrs > 0:
        # Rank n never sees the variable's shape. Size the receive buffer from
        # node 0's first message (tag 0) instead.
        status = MPI.Status()
        comm.Probe(source=0, tag=0, status=status)
        count = status.Get_count(MPI.DOUBLE)
        evaluate(n, (count,), comm, vartol=vartol, itrs=itrs)
    return None
def requiredComms(Z, W, tol=1e-3):
    '''
    Return the point-to-point communications each node needs for the given Z and W matrices.

    Only the strictly lower triangles are inspected: for each pair j < i,
    W[i, j] and Z[i, j] count as nonzero when ``np.isclose(value, 0, atol=tol)``
    is False, i.e. when their magnitude exceeds ``tol``.

    Args:
        Z (ndarray): the (n, n) Z matrix
        W (ndarray): the (n, n) W matrix
        tol (float, optional): magnitude at or below which an entry is treated as zero

    Returns:
        Comms_Data (list): one dict per node i with the keys
            WQ (list): nodes exchanging only W data with node i (in both directions)
            up_ZQ (list): nodes j < i which feed only Z data into node i
            down_ZQ (list): nodes k > i which receive only Z data from node i
            up_BQ (list): nodes j < i which feed both W and Z data into node i;
                node i feeds W data back to them
            down_BQ (list): nodes k > i which receive W and Z data from node i
                and feed W data back
    '''
    n = W.shape[0]
    comms = [
        {'WQ': [], 'up_ZQ': [], 'down_ZQ': [], 'up_BQ': [], 'down_BQ': []}
        for _ in range(n)
    ]
    for i in range(n):
        for j in range(i):
            w_nonzero = not np.isclose(W[i, j], 0, atol=tol)
            z_nonzero = not np.isclose(Z[i, j], 0, atol=tol)
            if w_nonzero and z_nonzero:
                comms[i]['up_BQ'].append(j)
                comms[j]['down_BQ'].append(i)
            elif w_nonzero:
                comms[j]['WQ'].append(i)
                comms[i]['WQ'].append(j)
            elif z_nonzero:
                comms[i]['up_ZQ'].append(j)
                comms[j]['down_ZQ'].append(i)
    return comms
def subproblem(i, data, resolvents, W, Z, comms_data, comm, gamma=0.5, itrs=100, vartol=None, verbose=False):
    """
    Run the frugal resolvent splitting iteration for a single node.

    Each iteration:
    1. Receive the upstream nodes' resolvent outputs.
    2. Evaluate this node's resolvent.
    3. Send the result to every node that needs it.
    4. Collect the remaining W contributions.
    5. Take a consensus step on ``v``.

    Args:
        i (int): index (MPI rank) of this node; selects the row of W and Z used
        data: problem data passed to the resolvent constructor
        resolvents (class): the resolvent class for this node (a single
            class despite the plural name); ``resolvents(data)`` must expose
            ``.shape`` and ``.prox(y)``
        W (ndarray): size (n, n) W matrix
        Z (ndarray): size (n, n) Z matrix
        comms_data (dict): this node's entry from ``requiredComms(Z, W)``
        comm (MPI communicator): the MPI communicator
        gamma (float, optional): step size for the consensus step
        itrs (int, optional): maximum number of iterations
        vartol (float, optional): when not None, early termination is enabled.
            Node 0 sends its ``v`` to rank n every iteration, and every node
            polls rank n for a new final iteration count. Only whether it is
            None matters here; the tolerance itself is applied on rank n.
        verbose (bool, optional): print progress every 500 iterations

    Returns:
        w_value (ndarray): this node's last resolvent output
        log: ``resolvent.log`` if the resolvent has that attribute, else None

    NOTE(review): the iteration number is used as the MPI tag, and MPI only
    guarantees tags up to 32767. Confirm MPI.TAG_UB if itrs can exceed that.
    """
    resolvent = resolvents(data)
    shape = resolvent.shape
    n = W.shape[0]  # the evaluator (if any) is expected on rank n
    recv_buf = np.ones(shape, dtype=np.float64)
    local_v = np.zeros(shape, dtype=np.float64)
    local_r = np.zeros(shape, dtype=np.float64)
    v_temp = np.zeros(shape, dtype=np.float64)
    # Every node that consumes this node's resolvent output each iteration.
    send_targets = (comms_data['down_ZQ'] + comms_data['down_BQ']
                    + comms_data['WQ'] + comms_data['up_BQ'])

    itr = 0
    terminated = False
    while itr < itrs:
        # The evaluator may send a final iteration count.
        # NOTE(review): a node that has already passed that count stops right
        # away, possibly at a different iteration than its neighbours. Confirm
        # the margin chosen by the evaluator is always large enough.
        if vartol is not None and comm.Iprobe(source=n, tag=0):
            itrs = comm.recv(source=n, tag=0)
            terminated = True
        if verbose and itr % 500 == 0:
            print(f'Node {i} iteration {itr}', flush=True)

        # Upstream Z-only feeds contribute to r.
        for k in comms_data['up_ZQ']:
            comm.Recv(recv_buf, source=k, tag=itr)
            local_r -= Z[i, k] * recv_buf
        # Upstream W-and-Z feeds contribute to both r and the W sum.
        for k in comms_data['up_BQ']:
            comm.Recv(recv_buf, source=k, tag=itr)
            local_r -= Z[i, k] * recv_buf
            v_temp += W[i, k] * recv_buf

        w_value = resolvent.prox(local_v + local_r)

        # Keep the request handles: MPI requires every nonblocking send to be
        # completed, and w_value must stay alive until it is.
        send_reqs = [comm.Isend(w_value, dest=k, tag=itr) for k in send_targets]

        # Remaining W contributions from W-only and downstream W-and-Z nodes.
        for k in comms_data['WQ']:
            comm.Recv(recv_buf, source=k, tag=itr)
            v_temp += W[i, k] * recv_buf
        for k in comms_data['down_BQ']:
            comm.Recv(recv_buf, source=k, tag=itr)
            v_temp += W[i, k] * recv_buf

        local_v = local_v - gamma * (W[i, i] * w_value + v_temp)

        for req in send_reqs:
            req.Wait()

        # NOTE(review): once the evaluator leaves its receive loop, sends here
        # have no matching receive until the termination message is seen.
        if i == 0 and vartol is not None and not terminated:
            comm.Send(local_v, dest=n, tag=itr)

        # Reset the accumulators in place.
        v_temp.fill(0)
        local_r.fill(0)
        itr += 1

    log = getattr(resolvent, 'log', None)
    return w_value, log
def evaluate(n, shape, comm, vartol=1e-7, itrs=100):
    """
    Watch node 0's iterates and tell every node when to stop.

    Receives node 0's array once per iteration (tag = iteration index). The
    convergence criterion is met once ``n`` consecutive iterates each differ
    from their predecessor by less than ``vartol`` in ``np.linalg.norm``.
    Every node ``0..n-1`` is then sent the iteration index ``itr + 2*n`` at
    which to stop (tag 0). This only happens if that index still lies more
    than ``2*n`` iterations before ``itrs``; otherwise no message is sent.

    Args:
        n (int): number of subproblem nodes (ranks 0..n-1); also the number
            of consecutive small changes required
        shape (tuple): shape of the float64 array node 0 sends each iteration
        comm (MPI communicator): the MPI communicator
        vartol (float, optional): change threshold between consecutive iterates
        itrs (int, optional): maximum number of iterates to receive
    """
    last = np.zeros(shape, dtype=np.float64)
    incoming = np.zeros(shape, dtype=np.float64)
    streak = 0
    itr = 0
    while streak < n and itr < itrs:
        comm.Recv(incoming, source=0, tag=itr)
        if np.linalg.norm(incoming - last) < vartol:
            streak += 1
        else:
            streak = 0
        # Swap buffers so the newest iterate becomes `last` without a copy.
        last, incoming = incoming, last
        itr += 1

    if streak >= n:
        print(f'Reached termination criteria on Iteration {itr}', flush=True)
    else:
        print(f'Reached iteration limit ({itr}) without meeting termination criteria', flush=True)

    # Leave the nodes a margin of 2*n iterations to pick up the message.
    advance = n * 2
    terminate_itr = itr + advance
    if itr < itrs - advance:
        for node in range(n):
            comm.send(terminate_itr, dest=node, tag=0)