#ifndef LLVM_SUPPORT_PARALLEL_H
#define LLVM_SUPPORT_PARALLEL_H
#include "llvm/ADT/STLExtras.h"
#include "llvm/Config/llvm-config.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/Threading.h"
#include <algorithm>
#include <condition_variable>
#include <functional>
#include <iterator>
#include <mutex>
#include <utility>
namespace llvm {
namespace parallel {
extern ThreadPoolStrategy strategy;
namespace detail {
#if LLVM_ENABLE_THREADS
/// Counter-based wait primitive: inc()/dec() adjust an outstanding-work count
/// under a mutex, and sync() blocks until that count is back to zero.
class Latch {
  uint32_t Count;
  mutable std::mutex Mutex;
  mutable std::condition_variable Cond;

public:
  /// Starts the counter at \p InitialCount (zero unless specified).
  explicit Latch(uint32_t InitialCount = 0) : Count(InitialCount) {}

  /// The counter is expected to have returned to zero before destruction.
  ~Latch() { assert(Count == 0); }

  /// Records one more outstanding item.
  void inc() {
    std::lock_guard<std::mutex> Guard(Mutex);
    Count += 1;
  }

  /// Retires one outstanding item; retiring the last one wakes all waiters.
  void dec() {
    std::lock_guard<std::mutex> Guard(Mutex);
    Count -= 1;
    if (Count != 0)
      return;
    Cond.notify_all();
  }

  /// Blocks the caller until the counter reaches zero.
  void sync() const {
    std::unique_lock<std::mutex> Guard(Mutex);
    while (Count != 0)
      Cond.wait(Guard);
  }
};
/// A group of tasks submitted with spawn() whose completion can be awaited
/// with sync(), which waits on the latch L until its count is zero.
/// The constructor, destructor and spawn() are defined out of line, so their
/// exact behavior is not visible in this header.
/// NOTE(review): parallel_sort and parallel_transform_reduce capture locals by
/// reference in spawned tasks and read results right after the group leaves
/// scope, so they presumably rely on ~TaskGroup() draining the group (e.g. via
/// L.sync()) -- confirm in the implementation file.
class TaskGroup {
// Tracks outstanding tasks; sync() waits on it.
Latch L;
// Set by the out-of-line constructor; presumably records whether spawned
// tasks actually run on a thread pool or are executed inline -- verify.
bool Parallel;
public:
TaskGroup();
~TaskGroup();
// Submits f to this group (definition out of line).
void spawn(std::function<void()> f);
// Blocks until the latch count reaches zero.
void sync() const { L.sync(); }
};
/// Ranges shorter than this are handed to the serial llvm::sort by
/// parallel_quick_sort instead of being split into further tasks.
constexpr ptrdiff_t MinParallelSize = 1 << 10;
/// Returns an iterator to the median (under \p Comp) of the first, middle and
/// last elements of the non-empty range [Start, End).
template <class RandomAccessIterator, class Comparator>
RandomAccessIterator medianOf3(RandomAccessIterator Start,
                               RandomAccessIterator End,
                               const Comparator &Comp) {
  RandomAccessIterator Mid = Start + (std::distance(Start, End) / 2);
  RandomAccessIterator Last = End - 1;
  if (Comp(*Start, *Last)) {
    // *Start < *Last: the median is Last unless Mid sorts before it, in which
    // case it is the larger of Start and Mid.
    if (!Comp(*Mid, *Last))
      return Last;
    return Comp(*Start, *Mid) ? Mid : Start;
  }
  // *Last <= *Start: the median is Start unless Mid sorts before it, in which
  // case it is the larger of Last and Mid.
  if (!Comp(*Mid, *Start))
    return Start;
  return Comp(*Last, *Mid) ? Mid : Last;
}
/// Quicksort of [Start, End) that hands the left partition of every split to
/// \p TG as a new task and keeps recursing on the right partition itself.
/// Ranges shorter than MinParallelSize, or reached once \p Depth is exhausted,
/// are sorted serially with llvm::sort.
template <class RandomAccessIterator, class Comparator>
void parallel_quick_sort(RandomAccessIterator Start, RandomAccessIterator End,
                         const Comparator &Comp, TaskGroup &TG, size_t Depth) {
  const bool Small = std::distance(Start, End) < detail::MinParallelSize;
  if (Small || Depth == 0) {
    llvm::sort(Start, End, Comp);
    return;
  }

  // Park the median-of-three pivot in the last slot, then partition the other
  // elements into "sorts before the pivot" followed by everything else.
  RandomAccessIterator Last = End - 1;
  RandomAccessIterator Median = medianOf3(Start, End, Comp);
  std::swap(*Last, *Median);
  auto SortsBeforePivot = [&Comp, Last](decltype(*Start) Elem) {
    return Comp(Elem, *Last);
  };
  RandomAccessIterator Pivot = std::partition(Start, Last, SortsBeforePivot);
  // Drop the pivot into its final sorted position between the two halves.
  std::swap(*Pivot, *Last);

  // Comp and TG are captured by reference: both must outlive the spawned task.
  const size_t ChildDepth = Depth - 1;
  TG.spawn([Start, Pivot, ChildDepth, &Comp, &TG] {
    parallel_quick_sort(Start, Pivot, Comp, TG, ChildDepth);
  });
  parallel_quick_sort(Pivot + 1, End, Comp, TG, ChildDepth);
}
template <class RandomAccessIterator, class Comparator>
void parallel_sort(RandomAccessIterator Start, RandomAccessIterator End,
const Comparator &Comp) {
TaskGroup TG;
parallel_quick_sort(Start, End, Comp, TG,
llvm::Log2_64(std::distance(Start, End)) + 1);
}
/// Upper bound on the number of tasks parallel_transform_reduce spawns for a
/// single call, limiting scheduling overhead on large inputs.
enum { MaxTasksPerGroup = 1024 };

/// Transform-reduce over the random-access range [Begin, End) using tasks.
///
/// The input is split into at most MaxTasksPerGroup contiguous chunks whose
/// sizes differ by at most one. Each chunk is folded in its own task, starting
/// from a copy of \p Init. The per-chunk results are then folded serially, in
/// chunk order.
///
/// Because every chunk starts from \p Init, \p Init is combined once per chunk.
/// It should therefore be an identity value of \p Reduce (e.g. 0 for
/// addition). \p Reduce should also be associative, since operations are
/// grouped differently than in a left-to-right serial fold.
///
/// \p Transform and \p Reduce are captured by reference in the spawned tasks
/// and may therefore be invoked concurrently.
template <class IterTy, class ResultTy, class ReduceFuncTy,
          class TransformFuncTy>
ResultTy parallel_transform_reduce(IterTy Begin, IterTy End, ResultTy Init,
                                   ReduceFuncTy Reduce,
                                   TransformFuncTy Transform) {
  size_t NumInputs = std::distance(Begin, End);
  if (NumInputs == 0)
    return std::move(Init);
  size_t NumTasks = std::min(static_cast<size_t>(MaxTasksPerGroup), NumInputs);
  // Slots are seeded with Init only because ResultTy need not be
  // default-constructible; each task overwrites its own slot.
  std::vector<ResultTy> Results(NumTasks, Init);
  {
    // TG is scoped so it is destroyed before Results is read below.
    // NOTE(review): this relies on ~TaskGroup() waiting for spawned tasks
    // (defined out of line) -- confirm.
    TaskGroup TG;
    // Each task gets TaskSize inputs; the first RemainingInputs tasks get one
    // extra so that every input is covered exactly once.
    size_t TaskSize = NumInputs / NumTasks;
    size_t RemainingInputs = NumInputs % NumTasks;
    IterTy TBegin = Begin;
    for (size_t TaskId = 0; TaskId < NumTasks; ++TaskId) {
      IterTy TEnd = TBegin + TaskSize + (TaskId < RemainingInputs ? 1 : 0);
      TG.spawn([=, &Transform, &Reduce, &Results] {
        ResultTy R = Init;
        // Move the accumulator into Reduce, as the serial fallback in
        // parallelTransformReduce does, instead of copying it every step.
        for (IterTy It = TBegin; It != TEnd; ++It)
          R = Reduce(std::move(R), Transform(*It));
        Results[TaskId] = std::move(R);
      });
      TBegin = TEnd;
    }
    assert(TBegin == End);
  }
  // Final serial fold over at most MaxTasksPerGroup partial results.
  ResultTy FinalResult = std::move(Results.front());
  for (ResultTy &PartialResult :
       makeMutableArrayRef(Results.data() + 1, Results.size() - 1))
    FinalResult = Reduce(std::move(FinalResult), std::move(PartialResult));
  // Plain return of a local: allows NRVO / implicit move.
  return FinalResult;
}
#endif
} }
/// Sorts [Start, End) with \p Comp (std::less of the value type by default).
/// When threads are enabled and the global strategy requests anything other
/// than exactly one thread, the task-parallel quicksort is used; otherwise
/// this is a plain llvm::sort.
template <class RandomAccessIterator,
          class Comparator = std::less<
              typename std::iterator_traits<RandomAccessIterator>::value_type>>
void parallelSort(RandomAccessIterator Start, RandomAccessIterator End,
                  const Comparator &Comp = Comparator()) {
#if LLVM_ENABLE_THREADS
  const bool UseTasks = parallel::strategy.ThreadsRequested != 1;
  if (UseTasks)
    return parallel::detail::parallel_sort(Start, End, Comp);
#endif
  llvm::sort(Start, End, Comp);
}
// Calls Fn once per index, presumably over the half-open range [Begin, End);
// parallelForEach maps the indices 0 .. End - Begin back onto an iterator
// range. Defined out of line, so whether indices are processed concurrently
// or in a particular order is not visible here. Fn should presumably tolerate
// concurrent, unordered invocation -- confirm in the implementation.
void parallelFor(size_t Begin, size_t End, function_ref<void(size_t)> Fn);
/// Applies \p Fn to every element of the random-access range [Begin, End) by
/// dispatching element offsets through parallelFor.
template <class IterTy, class FuncTy>
void parallelForEach(IterTy Begin, IterTy End, FuncTy Fn) {
  auto VisitOffset = [&](size_t I) { Fn(Begin[I]); };
  parallelFor(0, End - Begin, VisitOffset);
}
/// Folds Transform(X) over every element X of [Begin, End) with \p Reduce,
/// starting from \p Init.
///
/// When threads are enabled and the global strategy requests anything other
/// than exactly one thread, the work is chunked across tasks. In that case
/// \p Init is applied once per chunk, so it must be an identity value of
/// \p Reduce, and \p Reduce should be associative for the result to match the
/// serial fold below.
template <class IterTy, class ResultTy, class ReduceFuncTy,
          class TransformFuncTy>
ResultTy parallelTransformReduce(IterTy Begin, IterTy End, ResultTy Init,
                                 ReduceFuncTy Reduce,
                                 TransformFuncTy Transform) {
#if LLVM_ENABLE_THREADS
  if (parallel::strategy.ThreadsRequested != 1) {
    // All three by-value parameters are dead after this call, so hand them
    // over instead of copying them.
    return parallel::detail::parallel_transform_reduce(
        Begin, End, std::move(Init), std::move(Reduce), std::move(Transform));
  }
#endif
  for (IterTy I = Begin; I != End; ++I)
    Init = Reduce(std::move(Init), Transform(*I));
  return std::move(Init);
}
/// Range overload of parallelSort: sorts every element of \p R with \p Comp.
///
/// The default comparator is std::less of the range's iterator value_type,
/// matching the iterator overload. It is computed from
/// std::declval<RangeTy &>() rather than RangeTy(). The latter is ill-formed
/// when RangeTy is deduced as an lvalue reference (value-initializing a
/// reference), which made `parallelSort(Vec)` unusable without an explicit
/// comparator. It also required RangeTy to be default-constructible.
template <class RangeTy,
          class Comparator = std::less<typename std::iterator_traits<
              decltype(std::begin(std::declval<RangeTy &>()))>::value_type>>
void parallelSort(RangeTy &&R, const Comparator &Comp = Comparator()) {
  parallelSort(std::begin(R), std::end(R), Comp);
}
/// Range overload of parallelForEach: applies \p Fn to every element of \p R.
template <class RangeTy, class FuncTy>
void parallelForEach(RangeTy &&R, FuncTy Fn) {
  auto First = std::begin(R);
  auto Last = std::end(R);
  parallelForEach(First, Last, Fn);
}
/// Range overload of parallelTransformReduce: folds Transform(X) over every
/// element X of \p R with \p Reduce, starting from \p Init.
template <class RangeTy, class ResultTy, class ReduceFuncTy,
          class TransformFuncTy>
ResultTy parallelTransformReduce(RangeTy &&R, ResultTy Init,
                                 ReduceFuncTy Reduce,
                                 TransformFuncTy Transform) {
  auto First = std::begin(R);
  auto Last = std::end(R);
  return parallelTransformReduce(First, Last, Init, Reduce, Transform);
}
/// Runs \p Fn (which returns an Error) on every element of \p R via
/// parallelTransformReduce. The returned Errors are combined with joinErrors,
/// starting from Error::success(). Errors travel through the reduction as
/// opaque LLVMErrorRef handles (wrap/unwrap), so the accumulator is a plain
/// copyable value.
template <class RangeTy, class FuncTy>
Error parallelForEachError(RangeTy &&R, FuncTy Fn) {
  auto JoinHandles = [](LLVMErrorRef Lhs, LLVMErrorRef Rhs) {
    return wrap(joinErrors(unwrap(Lhs), unwrap(Rhs)));
  };
  auto RunOne = [&Fn](auto &&V) { return wrap(Fn(V)); };
  LLVMErrorRef Combined =
      parallelTransformReduce(std::begin(R), std::end(R),
                              wrap(Error::success()), JoinHandles, RunOne);
  return unwrap(Combined);
}
}
#endif