package com.distributed.search.grpc;
import java.net.InetSocketAddress;
import com.distributed.search.model.*;
import io.grpc.ManagedChannel;
import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder;
import java.util.*;

/**
 * Coordinator-side gRPC client: keeps one channel and one blocking stub per
 * worker address and runs the two-phase search (per-term document counts,
 * then scoring with globally computed IDF values) across those workers.
 *
 * <p>This class performs no synchronization of its own; {@code stubs} and
 * {@code channels} are a plain {@link HashMap} and {@link ArrayList}. Callers
 * that invoke {@link #updateWorkers} and {@link #performSearch} from different
 * threads must coordinate externally.
 */
public class SearchClient {
    /** Blocking stub per worker, keyed by the address string passed to {@link #updateWorkers}. */
    private final Map<String, SearchServiceGrpc.SearchServiceBlockingStub> stubs = new HashMap<>();
    /** Every channel currently open, so the next update can shut them all down. */
    private final List<ManagedChannel> channels = new ArrayList<>();

    /** One worker's slice of the file list: {@code count} files starting at {@code startIndex}. */
    private static final class Shard {
        final String address;
        final SearchServiceGrpc.SearchServiceBlockingStub stub;
        final int startIndex;
        final int count;

        Shard(String address, SearchServiceGrpc.SearchServiceBlockingStub stub, int startIndex, int count) {
            this.address = address;
            this.stub = stub;
            this.startIndex = startIndex;
            this.count = count;
        }
    }

    /**
     * Replaces the active worker set: shuts down every existing channel, then
     * opens a plaintext channel and blocking stub for each given address.
     *
     * <p>Addresses that are malformed or fail to set up are logged to stderr
     * and skipped; the remaining workers are still registered. An address that
     * appears more than once is registered only once.
     *
     * @param workerAddresses worker endpoints in {@code host:port} form
     */
    public void updateWorkers(List<String> workerAddresses) {
        // shutdownNow() also cancels calls still in flight on the old channels.
        for (ManagedChannel channel : channels) {
            channel.shutdownNow();
        }
        channels.clear();
        stubs.clear();

        for (String address : workerAddresses) {
            // A repeated address would open a second channel whose stub merely
            // overwrites the first one, leaving a redundant channel open.
            if (stubs.containsKey(address)) {
                continue;
            }
            try {
                String[] parts = address.split(":");
                if (parts.length < 2) {
                    System.err.println("Failed to connect to " + address + ": expected host:port");
                    continue;
                }
                String host = parts[0];
                int port = Integer.parseInt(parts[1]);

                ManagedChannel channel = NettyChannelBuilder.forAddress(new InetSocketAddress(host, port))
                        .usePlaintext()
                        .build();

                channels.add(channel);
                stubs.put(address, SearchServiceGrpc.newBlockingStub(channel));
                System.out.println("Successfully connected to Worker: " + host + ":" + port);

            } catch (RuntimeException e) {
                // Best effort: one bad worker must not prevent registering the others.
                System.err.println("Failed to connect to " + address + ": " + e.getMessage());
                e.printStackTrace();
            }
        }
    }

    /**
     * Orchestrates the 2-phase distributed search.
     *
     * <p>Phase 1 asks each worker how many documents in its slice contain each
     * term and sums those counts; the coordinator then derives one IDF per term.
     * Phase 2 sends the terms and IDFs back to the same workers for the same
     * slices and merges the scored documents.
     *
     * <p>A worker that fails in either phase is logged and skipped, so the
     * result may be partial.
     *
     * @param terms    search terms
     * @param allFiles all files to search; only its size is used here, workers
     *                 receive index ranges into it
     * @return scored documents ordered by descending score; empty when there
     *         are no workers or no files
     */
    public List<SearchResponse.DocumentResult> performSearch(List<String> terms, List<String> allFiles) {
        if (stubs.isEmpty()) {
            return Collections.emptyList();
        }

        int totalFiles = allFiles.size();
        if (totalFiles == 0) {
            // Nothing to assign to any worker; also avoids computing log10(0) below.
            return new ArrayList<>();
        }

        // Both phases must address identical slices, so compute them once.
        List<Shard> shards = assignShards(totalFiles);

        Map<String, Integer> globalTermCounts = collectDocumentCounts(terms, shards);
        Map<String, Double> globalIdfs = computeGlobalIdfs(terms, globalTermCounts, totalFiles);
        List<SearchResponse.DocumentResult> finalResults = collectScores(terms, globalIdfs, shards);

        finalResults.sort((a, b) -> Double.compare(b.getScore(), a.getScore()));
        return finalResults;
    }

    /** Splits {@code totalFiles} indices into contiguous slices of ceil(total / workers), one per worker. */
    private List<Shard> assignShards(int totalFiles) {
        int filesPerWorker = (int) Math.ceil((double) totalFiles / stubs.size());
        List<Shard> shards = new ArrayList<>();
        int startIndex = 0;
        for (Map.Entry<String, SearchServiceGrpc.SearchServiceBlockingStub> worker : stubs.entrySet()) {
            int count = Math.min(filesPerWorker, totalFiles - startIndex);
            if (count <= 0) {
                break; // more workers than files: the remaining workers get nothing
            }
            shards.add(new Shard(worker.getKey(), worker.getValue(), startIndex, count));
            startIndex += count;
        }
        return shards;
    }

    /** Phase 1: sums, per term, the document counts reported by each worker for its slice. */
    private Map<String, Integer> collectDocumentCounts(List<String> terms, List<Shard> shards) {
        Map<String, Integer> globalTermCounts = new HashMap<>();
        for (Shard shard : shards) {
            StatRequest request = StatRequest.newBuilder()
                    .addAllTerms(terms)
                    .setStartIndex(shard.startIndex)
                    .setCount(shard.count)
                    .build();
            try {
                // NOTE(review): no deadline is set on the stub, so a hung worker
                // appears able to block this call indefinitely; confirm and consider one.
                StatResponse response = shard.stub.getDocumentStats(request);
                response.getTermToDocumentCountMap().forEach((term, docCount) ->
                        globalTermCounts.merge(term, docCount, Integer::sum));
            } catch (Exception e) {
                // Best effort: this worker's slice is left out of the global counts.
                System.err.println("Worker " + shard.address + " Phase 1 error: " + e);
            }
        }
        return globalTermCounts;
    }

    /** Computes log10(totalFiles / max(1, documents containing the term)) for every term. */
    private static Map<String, Double> computeGlobalIdfs(
            List<String> terms, Map<String, Integer> globalTermCounts, int totalFiles) {
        Map<String, Double> globalIdfs = new HashMap<>();
        for (String term : terms) {
            int docsWithTerm = globalTermCounts.getOrDefault(term, 0);
            // max(1, ...) keeps the division defined for terms no worker reported.
            double idf = Math.log10((double) totalFiles / Math.max(1, docsWithTerm));
            globalIdfs.put(term, idf);
        }
        return globalIdfs;
    }

    /** Phase 2: asks each worker to score its slice with the global IDFs and merges the results. */
    private List<SearchResponse.DocumentResult> collectScores(
            List<String> terms, Map<String, Double> globalIdfs, List<Shard> shards) {
        List<SearchResponse.DocumentResult> finalResults = new ArrayList<>();
        for (Shard shard : shards) {
            CalculationRequest request = CalculationRequest.newBuilder()
                    .addAllTerms(terms)
                    .putAllGlobalIdfs(globalIdfs)
                    .setStartIndex(shard.startIndex)
                    .setCount(shard.count)
                    .build();
            try {
                SearchResponse response = shard.stub.getFinalScores(request);
                finalResults.addAll(response.getResultsList());
            } catch (Exception e) {
                // Best effort: this worker's documents are missing from the result.
                System.err.println("Worker " + shard.address + " Phase 2 error: " + e);
            }
        }
        return finalResults;
    }
}