import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.Scanner;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

/**
 * Sums a randomly filled int array twice: once in parallel on a fixed-size
 * thread pool, and once serially. Each sum is printed with its elapsed time
 * so the two can be compared.
 *
 * <p>Reads two integers from standard input: the thread count {@code p} and
 * the array length {@code N}.
 */
public class ParallelSum
{
  private static final long NANOS_PER_MILLI = 1_000_000L;

  public static void main (String[] args)
  {
    int numThreads;
    int N;
    // Closing the Scanner also closes System.in; nothing reads it afterwards.
    try (Scanner input = new Scanner (System.in))
    {
      System.out.print ("p ==> ");
      numThreads = input.nextInt ();
      System.out.println ("(hardware threads = "
                          + Runtime.getRuntime ().availableProcessors () + ")");
      System.out.print ("N ==> ");
      N = input.nextInt ();
      System.out.println ();
    }

    // Reject values that would otherwise surface as raw JDK exceptions
    // (newFixedThreadPool rejects p < 1; new int[N] rejects N < 0).
    if (numThreads < 1 || N < 0)
    {
      System.err.println ("p must be >= 1 and N must be >= 0");
      return;
    }

    int[] A = new int[N];
    Random rand = new Random ();
    final int UPPER_BOUND = 4;
    // nextInt's bound is exclusive, so elements fall in [0, UPPER_BOUND].
    for (int i = 0; i < N; ++i)
      A[i] = rand.nextInt (UPPER_BOUND + 1);

    // nanoTime is the monotonic clock intended for measuring intervals;
    // currentTimeMillis can jump if the wall clock is adjusted.
    long start = System.nanoTime ();
    int parallelSum;
    try
    {
      parallelSum = parallelSum (A, numThreads);
    }
    catch (InterruptedException e)
    {
      // Restore the interrupt flag instead of swallowing it.
      Thread.currentThread ().interrupt ();
      System.err.println ("Interrupted while waiting for the parallel sum");
      return;
    }
    catch (ExecutionException e)
    {
      // A slice task failed; do not report a partial (wrong) sum.
      e.printStackTrace ();
      return;
    }
    long elapsedMs = (System.nanoTime () - start) / NANOS_PER_MILLI;
    System.out.println ("// sum:       " + parallelSum);
    System.out.println ("// time:      " + elapsedMs + " ms\n");

    start = System.nanoTime ();
    int serialSum = serialSum (A);
    elapsedMs = (System.nanoTime () - start) / NANOS_PER_MILLI;
    System.out.println ("Serial sum:   " + serialSum);
    System.out.println ("Serial time:  " + elapsedMs + " ms");
  }

  /**
   * Sums {@code A} by splitting it into {@code numThreads} contiguous slices.
   * Each slice is summed by its own task on a fixed-size thread pool, and the
   * partial sums are then added together. The pool is shut down before
   * returning, whether or not a task fails.
   *
   * @param A          array to sum; only read, never modified
   * @param numThreads number of slices and pool threads; must be at least 1
   * @return the sum of all elements, accumulated in {@code int} (it wraps on
   *         overflow, just as the serial sum does)
   * @throws IllegalArgumentException if {@code numThreads < 1}
   * @throws InterruptedException     if interrupted while waiting for a task
   * @throws ExecutionException       if any slice task throws
   */
  static int parallelSum (final int[] A, int numThreads)
    throws InterruptedException, ExecutionException
  {
    ExecutorService executor = Executors.newFixedThreadPool (numThreads);
    try
    {
      List<Future<Integer>> futures = new ArrayList<> (numThreads);
      for (int i = 0; i < numThreads; ++i)
      {
        // Locals captured by a lambda must be (effectively) final.
        final int threadId = i;
        Callable<Integer> task = () -> sumLocal (threadId, numThreads, A);
        futures.add (executor.submit (task));
      }

      int sum = 0;
      for (Future<Integer> f : futures)
        sum += f.get ();
      return sum;
    }
    finally
    {
      executor.shutdown ();
    }
  }

  /**
   * Sums slice {@code id} of {@code A}, where {@code A} is divided into
   * {@code numThreads} near-equal contiguous slices. Consecutive ids tile the
   * array exactly: each slice's upper bound is the next slice's lower bound.
   *
   * @param id         slice index, expected in {@code [0, numThreads)}
   * @param numThreads total number of slices
   * @param A          array being summed
   * @return the sum of the elements in this slice
   */
  private static int sumLocal (int id, int numThreads, final int[] A)
  {
    // Compute in long: id * A.length can exceed Integer.MAX_VALUE for large
    // arrays or many slices. The quotient is <= A.length, so it fits in int.
    int lowerBound = (int) ((long) id * A.length / numThreads);
    int upperBound = (int) ((long) (id + 1) * A.length / numThreads);

    int localSum = 0;
    for (int i = lowerBound; i < upperBound; ++i)
      localSum += A[i];
    return localSum;
  }

  /**
   * Sums every element of {@code A} on the calling thread.
   *
   * @param A array to sum
   * @return the sum, accumulated in {@code int}
   */
  private static int serialSum (final int[] A)
  {
    int sum = 0;
    for (int e : A)
      sum += e;
    return sum;
  }

  /**
   * {@link Callable} that sums one slice of an array. It is a named-class
   * alternative to the lambda submitted by {@code parallelSum}.
   */
  static class LocalSum implements Callable<Integer>
  {
    final int id;
    final int numThreads;
    final int[] A;

    LocalSum (int id, int numThreads, int[] A)
    {
      this.id = id;
      this.numThreads = numThreads;
      this.A = A;
    }

    @Override
    public Integer call () throws Exception
    {
      // Delegate so both task forms share one overflow-safe slicing rule.
      return sumLocal (id, numThreads, A);
    }
  }
}
