#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <mpi.h>

#define ARRAY_SIZE 20

// Compute the inclusive prefix sum of a single block.
//
//   block_array  - input values; the first block_size elements are read
//   block_size   - number of elements to process; when it is <= 0 nothing is
//                  read or written
//   local_prefix - output; local_prefix[i] receives
//                  block_array[0] + ... + block_array[i]
//
// Each input element is read before the output element with the same index is
// written, so block_array and local_prefix may point to the same array.
void calculate_local_prefix_sum(const int* block_array, int block_size, int* local_prefix) {
    int running_total = 0;

    // Starting the loop at index 0 (rather than seeding element 0 up front)
    // keeps an empty block from touching either buffer.
    for (int i = 0; i < block_size; i++) {
        running_total += block_array[i];
        local_prefix[i] = running_total;
    }
}

// Compute this rank's slice of a distributed inclusive prefix sum.
//
// The global sequence is the concatenation, in rank order, of every rank's
// block_array. On return, block_prefix[i] holds the sum of all elements owned
// by lower ranks plus block_array[0] + ... + block_array[i].
//
//   block_array  - this rank's input values (block_size elements)
//   block_size   - number of elements in this rank's block; a value <= 0
//                  contributes a total of 0 and writes nothing
//   block_prefix - output buffer with room for block_size elements
//   communicator - communicator whose ranks hold the blocks
//
// This is a collective operation: every rank of `communicator` must call it.
void prefix_mpi(int* block_array, int block_size, int* block_prefix, MPI_Comm communicator) {
    int rank;
    MPI_Comm_rank(communicator, &rank);

    // Step 1: scan the local block straight into the output buffer and keep
    // its last element, which is the sum of the whole block.
    int block_total = 0;
    if (block_size > 0) {
        calculate_local_prefix_sum(block_array, block_size, block_prefix);
        block_total = block_prefix[block_size - 1];
    }

    // Step 2: an exclusive scan over the block totals gives each rank the sum
    // of everything owned by the ranks before it.
    int offset = 0;
    MPI_Exscan(&block_total, &offset, 1, MPI_INT, MPI_SUM, communicator);

    // MPI leaves the MPI_Exscan receive buffer undefined on rank 0, so reset
    // it explicitly: no elements precede the first block.
    if (rank == 0) {
        offset = 0;
    }

    // Step 3: shift the local scan by the preceding ranks' total.
    for (int i = 0; i < block_size; i++) {
        block_prefix[i] += offset;
    }
}

// Build a newly allocated inclusive prefix-sum array on a single process.
//
//   array - input values; the first `size` elements are read
//   size  - number of elements in `array`
//
// Returns a malloc'd buffer of `size` ints in which element i equals
// array[0] + ... + array[i]. The caller owns the buffer and must free() it.
// Returns NULL when size <= 0 or when the allocation fails.
int* calculate_sequential_prefix_sum(const int* array, int size) {
    if (size <= 0) {
        return NULL;
    }

    int* prefix_sum = malloc((size_t)size * sizeof *prefix_sum);
    if (prefix_sum == NULL) {
        return NULL;
    }

    int running_total = 0;
    for (int i = 0; i < size; i++) {
        running_total += array[i];
        prefix_sum[i] = running_total;
    }

    return prefix_sum;
}


// Test driver for prefix_mpi.
//
// Rank 0 fills an array with pseudo-random values in [0, 10], the array is
// scattered in equal blocks, every rank runs prefix_mpi on its block, and the
// results are gathered back on rank 0. Rank 0 then checks the gathered result
// twice: against a running sum and against calculate_sequential_prefix_sum.
//
// Returns EXIT_FAILURE on rank 0 if any element is wrong or the reference
// array cannot be built; returns EXIT_SUCCESS otherwise. An allocation failure
// aborts the whole job through MPI_Abort.
int main(int argc, char** args) {
    MPI_Init(&argc, &args);

    int my_rank;
    int com_size;
    MPI_Comm_rank(MPI_COMM_WORLD, &my_rank);
    MPI_Comm_size(MPI_COMM_WORLD, &com_size);

    int total_array_size = 2048;

    // MPI_Scatter/MPI_Gather below hand every rank the same element count, so
    // round the total up to the next multiple of the number of processes.
    if (total_array_size % com_size != 0)
        total_array_size = (total_array_size / com_size + 1) * com_size;

    int block_size = total_array_size / com_size;
    int* total_array = NULL;
    int* total_prefix = NULL;

    // Only rank 0 holds the full input and the full result.
    if (my_rank == 0) {
        total_array = malloc((size_t)total_array_size * sizeof *total_array);
        total_prefix = malloc((size_t)total_array_size * sizeof *total_prefix);
        if (total_array == NULL || total_prefix == NULL) {
            fprintf(stderr, "Rank 0: failed to allocate arrays of %d ints\n", total_array_size);
            MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE);
        }

        // Fill the total array with pseudo-random values in [0, 10]
        for (int i = 0; i < total_array_size; i++)
            total_array[i] = rand() % 11;
    }

    int* block_array = malloc((size_t)block_size * sizeof *block_array);
    int* block_prefix = malloc((size_t)block_size * sizeof *block_prefix);
    if (block_array == NULL || block_prefix == NULL) {
        fprintf(stderr, "Rank %d: failed to allocate block buffers of %d ints\n", my_rank, block_size);
        MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE);
    }

    // Scatter the total array among processes
    MPI_Scatter(total_array, block_size, MPI_INT, block_array, block_size, MPI_INT, 0, MPI_COMM_WORLD);

    // Compute this rank's slice of the global prefix sum
    prefix_mpi(block_array, block_size, block_prefix, MPI_COMM_WORLD);

    // Gather the slices on rank 0 to form the total prefix array
    MPI_Gather(block_prefix, block_size, MPI_INT, total_prefix, block_size, MPI_INT, 0, MPI_COMM_WORLD);

    // Number of failed checks; only rank 0 ever increments it.
    int mismatches = 0;

    if (my_rank == 0) {
        // Check 1: compare against a running sum of the input
        int accum = 0;
        for (int i = 0; i < total_array_size; i++) {
            accum += total_array[i];
            if (total_prefix[i] != accum) {
                printf("Error at index %i: %i expected, %i computed\n", i, accum, total_prefix[i]);
                mismatches++;
            }
        }

        printf("Test completed!\n");

        // Check 2: compare against the sequential reference implementation
        int* seq_prefix = calculate_sequential_prefix_sum(total_array, total_array_size);
        if (seq_prefix == NULL) {
            fprintf(stderr, "Rank 0: could not build the sequential reference array\n");
            mismatches++;
        } else {
            for (int i = 0; i < total_array_size; i++) {
                if (seq_prefix[i] != total_prefix[i]) {
                    printf("Verification Error at index %i: %i expected, %i computed\n", i, seq_prefix[i], total_prefix[i]);
                    mismatches++;
                }
            }
        }
        free(seq_prefix);

        // Free the rank-0-only buffers
        free(total_array);
        free(total_prefix);
    }
    free(block_array);
    free(block_prefix);

    MPI_Finalize();

    // Make a failed verification visible to the launcher through the exit status.
    return mismatches == 0 ? EXIT_SUCCESS : EXIT_FAILURE;
}
