from mpi4py import MPI
import numpy as np

def reduce_tree(sendbuf, root=0, comm=None):
    """
    Tree-based reduce (sum) for a 1D numpy array.

    Ranks are arranged in a binary tree rooted at ``root``. Each rank first
    receives and adds the partial sums of its (up to two) children, then
    sends its own partial sum to its parent. Only ``root`` gets the result.

    Parameters
    ----------
    sendbuf : array_like
        This rank's contribution. It is copied, never modified in place.
        Presumably every rank passes an array of the same shape and dtype
        (the receive buffers are shaped like the local copy) -- TODO confirm
        with callers.
    root : int, optional
        Rank that receives the final sum (default 0). Must be in [0, size).
    comm : mpi4py.MPI.Comm, optional
        Communicator to reduce over; defaults to ``MPI.COMM_WORLD``.

    Returns
    -------
    numpy.ndarray or None
        The element-wise sum at ``root``; ``None`` on every other rank.

    Raises
    ------
    ValueError
        If ``root`` is not a valid rank of ``comm``.
    """
    if comm is None:
        comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    size = comm.Get_size()
    if not 0 <= root < size:
        raise ValueError(f"root must be in [0, {size}), got {root}")

    # Private C-contiguous copy: MPI buffers must be contiguous, and the
    # in-place additions below must not clobber the caller's array.
    local = np.array(sendbuf, copy=True, order="C")

    # Renumber ranks so the tree is rooted at `root` (virtual rank 0).
    # Using raw ranks only works for root == 0; for any other root, rank 0
    # would compute parent -1 and the real root would try to send upward.
    vrank = (rank - root) % size

    children = [c for c in (2 * vrank + 1, 2 * vrank + 2) if c < size]
    if children:
        tmp = np.empty_like(local)  # one receive buffer reused for both children
        for vchild in children:
            comm.Recv(tmp, source=(vchild + root) % size, tag=0)
            local += tmp

    if vrank != 0:
        vparent = (vrank - 1) // 2
        comm.Send(local, dest=(vparent + root) % size, tag=0)
        return None
    return local

def sequential_reduce(all_data):
    """
    Serial reference: sum the per-rank arrays element-wise.

    ``all_data`` is stacked along axis 0 (one row per rank), and that axis
    is collapsed by summation.
    """
    stacked = np.asanyarray(all_data)
    return stacked.sum(axis=0)

if __name__ == "__main__":
    import sys

    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    size = comm.Get_size()

    # Every rank parses argv itself, so all ranks reach the same decision and
    # exit together; only rank 0 prints to avoid duplicated output.
    N = None
    if len(sys.argv) == 2:
        try:
            N = int(sys.argv[1])
        except ValueError:
            N = None
    if N is None or N < 0:
        if rank == 0:
            print("Usage: python reduce.py <array_size>", file=sys.stderr)
        # Non-zero status: a usage error must not look like a successful run.
        sys.exit(1)

    # Distinct, reproducible data per rank.
    np.random.seed(rank + 1)
    local_array = np.random.randint(0, 10, size=N)

    # Barrier before starting the clock so all ranks begin together; barrier
    # after so t2 is taken only once every rank has finished its part.
    comm.Barrier()
    t1 = MPI.Wtime()
    result = reduce_tree(local_array)
    comm.Barrier()
    t2 = MPI.Wtime()
    parallel_time = t2 - t1

    # Collective: every rank must call gather. Rank 0 receives all inputs to
    # compute a serial reference result for the correctness check.
    all_data = comm.gather(local_array, root=0)

    if rank == 0:
        t3 = MPI.Wtime()
        seq_result = sequential_reduce(np.array(all_data))
        t4 = MPI.Wtime()
        sequential_time = t4 - t3

        print(f"Processes: {size}")
        print(f"Array size: {N}")
        print(f"Parallel reduce time: {parallel_time:.6f} s")
        print(f"Sequential reduce time: {sequential_time:.6f} s")

        correct = np.array_equal(seq_result, result)
        print(f"Correct: {correct}")
