from mpi4py import MPI
import numpy as np
import time 


def reduce_tree(local_data, rank, size, comm, *, op=np.add):
    """Reduce ``local_data`` across ``size`` processes; return the result on rank 0.

    Recursive-doubling exchange: at step ``s = 1, 2, 4, ...`` each rank swaps
    its running partial result with rank ``rank ^ s`` via ``comm.Sendrecv``
    and combines the two with ``op``. A rank whose partner would be
    ``>= size`` sits that step out, so ``size`` need not be a power of two.

    With a power-of-two ``size`` every rank ends up holding the full result.
    Otherwise only rank 0 is guaranteed to (other ranks may miss values from
    skipped partners), which is why only rank 0 returns the data.

    Args:
        local_data: This rank's contribution (array-like). It is copied into
            a private contiguous array, so the caller's buffer is never
            modified.
        rank: This process's rank.
        size: Number of participating processes.
        comm: Communicator providing ``Sendrecv`` (e.g. ``MPI.COMM_WORLD``).
        op: Binary, associative element-wise operation (default ``np.add``).
            Operands are always passed as (lower-rank data, higher-rank data),
            so non-commutative ops still yield a rank-ordered result on rank 0.

    Returns:
        The reduced NumPy array on rank 0, ``None`` on every other rank.
    """
    # Private, contiguous copy: Sendrecv needs a real buffer, and the caller's
    # array must not be overwritten with partial sums.
    work = np.array(local_data)

    print(f"[+] Process {rank} starting with local data: {work}")

    step = 1
    while step < size:
        # XOR pairs this rank with the one differing only in bit `step`.
        partner = rank ^ step

        if partner < size:
            # Exchange data with the partner; the buffer is fully overwritten.
            recv_data = np.empty_like(work)
            comm.Sendrecv(work, dest=partner, recvbuf=recv_data, source=partner)
            print(f"[+] Process {rank} received data from process {partner}: {recv_data}")

            # Combine in rank order so associative, non-commutative ops work.
            if partner > rank:
                work = op(work, recv_data)
            else:
                work = op(recv_data, work)

            print(f"[+] Process {rank} updated local data: {work}")

        step *= 2

    return work if rank == 0 else None

def main():
    """Run ``reduce_tree`` on ``MPI.COMM_WORLD`` and report result and timing.

    Rank ``r`` contributes ``[r + 1]``, so rank 0 should print the sum
    ``1 + 2 + ... + size``. The reported time is the slowest rank's elapsed
    time, measured with the MPI wall-clock timer after a common barrier.
    """
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    size = comm.Get_size()

    # Initialize the local data for each process
    local_data = np.array([rank + 1], dtype=int)

    # Line all ranks up first so the timer measures the reduction itself,
    # not start-up skew between processes.
    comm.Barrier()
    start_time = MPI.Wtime()

    # Perform the tree based reduce operation
    result = reduce_tree(local_data, rank, size, comm)

    elapsed_time = MPI.Wtime() - start_time

    # The reduction is only finished once the slowest rank is done, so report
    # the maximum elapsed time across ranks (available on rank 0 only).
    max_elapsed = comm.reduce(elapsed_time, op=MPI.MAX, root=0)

    if rank == 0:
        print(f"[+] Final result after reduction: {result}")
        print(f"[+] Time taken for the reduction: {max_elapsed:.6f} seconds")


# Run only when executed as a script (e.g. under `mpiexec -n 4 python <script>.py`).
if __name__ == "__main__":
    main()
