#include <chrono>
#include <iostream>
#include <limits>
#include <string>
#include <thread>
#include <vector>

#include "synchronization.h"

using namespace std;

// Keeps track of how many worker threads have finished Workgroup::doWork.
// kernel() resets it to 0 before each pass; doWork() increments it with the
// GCC/Clang __sync_add_and_fetch builtin (an atomic read-modify-write).
// NOTE(review): `volatile` by itself gives no atomicity or inter-thread
// ordering guarantees, so plain reads/writes of this variable from different
// threads are still racy. Consider std::atomic<int> (and fetch_add in doWork).
volatile int threadsFinished;

void kernel(int nIterations, int nThreads, bool synced) {
    int result = nIterations * nThreads;
    auto threads = new thread*[nThreads];
    Workgroup workgroup;
    workgroup.setIterationsPerThread(nIterations);
    workgroup.setIsSynchronized(synced);
	threadsFinished = 0;
	auto begin = chrono::high_resolution_clock::now();

    for (int i = 0; i < nThreads; i++) {
        auto t = new thread(&Workgroup::doWork, &workgroup);
        t->detach();
        threads[i] = t;
    }

    // Wait for the threads to finish
    while (threadsFinished < nThreads);

	auto end = chrono::high_resolution_clock::now();
	auto elapsed = chrono::duration_cast<chrono::milliseconds>(end - begin).count();
   	cout << "  Result for " << result << " iterations (" << nThreads << " threads):" << endl;
   	cout << "  - Sync      : " << (synced ? "yes" : "no") << endl;
   	cout << "  - Time [ms] : " << elapsed << endl;
   	cout << "  - Value     : " << workgroup.getSigma() << endl;
   	cout << "  - Valid     : " << (result == workgroup.getSigma() ? "yes" : "no") << endl;
    cout << endl;

    for (int i = 0; i < nThreads; i++) {
    	delete threads[i];
    }

    delete[] threads;
}

// Prompts with `message` on stdout and reads an int from stdin, re-prompting
// until a valid integer is entered.
//
// After a failed parse, the stream's error state is cleared and the rest of
// the offending line is discarded before retrying. Otherwise the same bad
// token would be re-read forever.
//
// Returns the number read, or 0 if stdin hits end-of-file (or becomes
// unusable) before a valid integer is read.
int receive(const std::string& message) {
    for (;;) {
        std::cout << "  " << message << ": ";
        int num;
        if (std::cin >> num)
            return num;

        // Nothing more can be read: give up instead of looping forever.
        if (std::cin.eof() || std::cin.bad()) {
            std::cout << std::endl;
            return 0;
        }

        std::cout << "  Invalid input! Only integer numbers are valid." << std::endl;
        std::cin.clear();
        std::cin.ignore(std::numeric_limits<std::streamsize>::max(), '\n');
    }
}

// Entry point of the exercise.
//
// Asks for the per-thread iteration count, then the thread count. It then
// benchmarks the shared counter twice: first without synchronization, then
// with it.
void Synchronization::run() {
    std::cout << "> Running the synchronization exercise" << std::endl;
    const int iterations = receive("Iterations per threads");
    const int threadCount = receive("Number of threads");

    // Unlocked pass first, then the locked pass, so the outputs can be compared.
    for (const bool synced : {false, true})
        kernel(iterations, threadCount, synced);
}

// Zeroes the shared counter and registers the two loop bodies that doWork
// chooses between:
//   workers[false] - increments sigma with no locking (the racy variant)
//   workers[true]  - increments sigma while holding `sync`
//
// NOTE(review): both lambdas capture `this`, so they stay bound to the object
// that constructed them. If a Workgroup were ever copied, the copy's workers
// would still modify the original's sigma/sync. Confirm the header makes
// Workgroup non-copyable.
Workgroup::Workgroup() {
    this->sigma = 0;

    workers[false] = [this] { ++sigma; };
    workers[true] = [this] {
        sync.lock();
        ++sigma;
        sync.unlock();
    };
}

int Workgroup::getSigma() { 
	return sigma;
}

int Workgroup::getIterationsPerThread() { 
	return iterationsPerThreads;
}

void Workgroup::setIterationsPerThread(int value) { 
	iterationsPerThreads = value; 
}

bool Workgroup::getIsSynchronized() { 
	return synchronized; 
}

void Workgroup::setIsSynchronized(bool value) { 
	synchronized = value; 
}

void Workgroup::doWork() {
	// Pick the right function for the body of the loop
    auto worker = workers[synchronized];

    for (int i = 0; i < iterationsPerThreads; ++i)
        worker();

    // Use this GCC macro for improved increment with locking
    __sync_add_and_fetch(&threadsFinished, 1);
}