package org.example;

import mpi.*;
import java.util.Random;

/**
 * Distributes an array of {@code totalElements} random doubles across all MPI processes, computes
 * a local average on every process, and combines those local averages on rank 0 into a global
 * average using a count-weighted mean.
 *
 * <p>Every rank receives {@code totalElements / size} elements and the last rank additionally
 * takes the remainder. Rank 0 generates the data and hands each worker its slice with
 * point-to-point {@code Send}/{@code Recv}. The per-rank averages and element counts are then
 * collected on rank 0 with {@code Gather}.
 */
public class MPI_AVG {
    public static void main(String[] args) throws Exception {
        MPI.Init(args);

        // Ensure all processes are ready before starting the clock
        MPI.COMM_WORLD.Barrier();
        long startTime = System.nanoTime();

        int rank = MPI.COMM_WORLD.Rank();
        int size = MPI.COMM_WORLD.Size();

        int totalElements = 1000;

        // How many elements this rank owns (the last rank also absorbs the remainder).
        // When size > totalElements, every rank except the last owns 0 elements.
        int mySize = chunkSize(rank, size, totalElements);

        double[] localBuffer = new double[mySize];

        // --- Task 1: Distribution ---
        if (rank == 0) {
            double[] fullMatrix = new double[totalElements];
            Random rand = new Random();
            for (int i = 0; i < totalElements; i++) {
                fullMatrix[i] = rand.nextDouble() * 100;
            }

            // Manually send data to handle potential unequal sizes (Scatterv-like logic).
            // Rank 0 keeps the slice [0, mySize), so worker slices start immediately after it.
            int offset = mySize;
            for (int dest = 1; dest < size; dest++) {
                int sendSize = chunkSize(dest, size, totalElements);
                MPI.COMM_WORLD.Send(fullMatrix, offset, sendSize, MPI.DOUBLE, dest, 0);
                offset += sendSize;
            }
            // Rank 0 keeps its own part
            System.arraycopy(fullMatrix, 0, localBuffer, 0, mySize);
        } else {
            // Workers receive their specific part
            MPI.COMM_WORLD.Recv(localBuffer, 0, mySize, MPI.DOUBLE, 0, 0);
        }

        // Printing received elements
        System.out.println("Rank " + rank + " received " + mySize + " elements.");

        // --- Task 2: Local Average Calculation ---
        // An empty partition yields 0.0 (not NaN), so it cannot poison the weighted mean below.
        double localAvg = localAverage(localBuffer);
        System.out.println("Rank " + rank + " Local Average: " + String.format("%.2f", localAvg));

        // --- Task 3: Global Average using Local Averages (Weighted Average) ---
        // Gather every rank's local average and its element count (the weight) onto rank 0.
        double[] allLocalAvgs = new double[size];
        int[] allCounts = new int[size];

        MPI.COMM_WORLD.Gather(new double[]{localAvg}, 0, 1, MPI.DOUBLE, allLocalAvgs, 0, 1, MPI.DOUBLE, 0);
        MPI.COMM_WORLD.Gather(new int[]{mySize}, 0, 1, MPI.INT, allCounts, 0, 1, MPI.INT, 0);

        if (rank == 0) {
            double globalAverage = weightedMean(allLocalAvgs, allCounts);

            System.out.println("\n--- FINAL CALCULATION (Weighted Mean) ---");
            System.out.println("Global Average derived from local averages: " + String.format("%.4f", globalAverage));
            System.out.println("========================================\n");
        }

        // Wait for every rank to finish before stopping the clock.
        MPI.COMM_WORLD.Barrier();
        long endTime = System.nanoTime();

        if (rank == 0) {
            double duration = (endTime - startTime) / 1_000_000.0;
            System.out.println("Execution Time for " + size + " processes: " + duration + " ms");
        }

        MPI.Finalize();
    }

    /**
     * Returns the number of elements assigned to {@code rank} when {@code totalElements} are
     * split across {@code size} ranks.
     *
     * <p>Each rank gets {@code totalElements / size}; the last rank ({@code size - 1}) also takes
     * the remainder. Used by both the sending (rank 0) and receiving sides so they always agree.
     *
     * @param rank          the rank whose share is requested, in {@code [0, size)}
     * @param size          the number of ranks; must be positive
     * @param totalElements the total number of elements being distributed
     * @return the element count owned by {@code rank}
     */
    private static int chunkSize(int rank, int size, int totalElements) {
        int baseSize = totalElements / size;
        return (rank == size - 1) ? totalElements - baseSize * (size - 1) : baseSize;
    }

    /**
     * Computes the arithmetic mean of {@code values}.
     *
     * @param values the values to average
     * @return the mean, or {@code 0.0} for an empty array (an empty partition has weight 0 in the
     *     global weighted mean, so its average value is irrelevant but must not be NaN)
     */
    private static double localAverage(double[] values) {
        if (values.length == 0) {
            return 0.0;
        }
        double sum = 0;
        for (double v : values) {
            sum += v;
        }
        return sum / values.length;
    }

    /**
     * Combines per-rank averages into a global mean: {@code Sum(avg_i * count_i) / Sum(count_i)}.
     *
     * @param averages the local average of each rank
     * @param counts   the number of elements each local average was computed over (same length)
     * @return the count-weighted mean, or {@link Double#NaN} if every count is zero
     */
    private static double weightedMean(double[] averages, int[] counts) {
        double weightedSum = 0;
        long totalCount = 0;
        for (int i = 0; i < averages.length; i++) {
            if (counts[i] == 0) {
                continue; // empty partition contributes nothing
            }
            weightedSum += averages[i] * counts[i];
            totalCount += counts[i];
        }
        return totalCount == 0 ? Double.NaN : weightedSum / totalCount;
    }
}