VIPRA Documentation
Loading...
Searching...
No Matches
parameter_sweep.hpp
1#pragma once
2
3#ifdef VIPRA_USE_MPI
4#include <mpi.h>
5#endif
6
7#include <cstddef>
8#include <string>
9#include <type_traits>
10
11#include "vipra/logging/logging.hpp"
12#include "vipra/modules/serializable.hpp"
13#include "vipra/parameter_sweep/ps_util.hpp"
14#include "vipra/simulation/sim_type.hpp"
15#include "vipra/special_modules/parameters.hpp"
16#include "vipra/types/util/result_or_void.hpp"
17#include "vipra/util/timing.hpp"
18
19namespace VIPRA {
21 public:
22 static void initialize(int argc, char** argv)
23 {
24#ifdef VIPRA_USE_MPI
25 MPI_Init(&argc, &argv);
26 MPI_Comm_dup(MPI_COMM_WORLD, &comm);
27 MPI_Comm_rank(comm, &rank);
28 MPI_Comm_size(comm, &size);
29 Log::info("MPI Initialized, rank: {}, size: {}", rank, size);
30#else
31 rank = 0;
32 size = 1;
33#endif
34 }
35
47 static void run(std::string const& installPath, std::string const& modulesPath,
48 std::string const& pedPath, std::string const& mapPath,
49 std::string const& paramsPath, size_t count, auto&& callback = VOID{})
50 {
51 // Create the simulation and load the modules
53 sim.set_install_dir(installPath);
54 sim.set_modules(modulesPath);
55
56 Parameters params;
57
58 _mpiTimings.start_new();
59 _mpiTimings.pause();
60
61 _inputTimings.start_new();
62 load_inputs(params.get_input(), paramsPath);
63 // disseminate_input(params.get_input());
64 load_inputs(sim.get_map_input(), mapPath);
65 if ( ! pedPath.empty() ) load_inputs(sim.get_ped_input(), pedPath);
66 _inputTimings.stop();
67
68 size_t localCount = sim_count(rank, size, count);
69
70 // add the correct simulation number for the current worker
71 // add, because this may be called multiple times
72 sim.add_sim_id(start_sim_id(rank, size, count));
73
74 for ( size_t i = 0; i < localCount; ++i ) {
75 _timings.start_new();
76 // run the simulation
77 // if a callback is provided, call that on completion
78 if constexpr ( std::is_same_v<decltype(callback), VIPRA::VOID> ) {
79 sim.run_sim(params);
80 }
81 else {
82 sim.run_sim(params);
83 callback(sim.get_sim_id());
84 }
85 _timings.stop();
86
87 sim.reset_modules();
88 }
89
90 // update each worker to the correct sim count
91 sim.set_sim_id(count);
92
93 sim.output_timings();
94 _timings.output_timings();
95
96#ifdef VIPRA_USE_MPI
97 _mpiTimings.resume();
98 MPI_Barrier(MPI_COMM_WORLD);
99 _mpiTimings.stop();
100
101 _mpiTimings.output_timings();
102#endif
103 }
104
105 [[nodiscard]] static auto get_rank() -> int { return rank; }
106 [[nodiscard]] static auto get_size() -> int { return size; }
107 [[nodiscard]] static auto is_parallel() -> bool { return size > 1; }
108 [[nodiscard]] static auto is_root() -> bool { return rank == 0; }
109
110 private:
/**
 * @brief Finalizes MPI when destroyed, if MPI is still active.
 *
 * The class holds a static instance of this type, so MPI shutdown happens
 * during static destruction. MPI_Finalize may be called only after MPI_Init
 * and at most once, so the destructor checks both MPI_Initialized and
 * MPI_Finalized. This makes extra instances, or a program that already
 * finalized MPI itself, harmless. Without VIPRA_USE_MPI it does nothing.
 */
struct DeferredFinalize {
  DeferredFinalize(DeferredFinalize const&) = default;
  DeferredFinalize(DeferredFinalize&&) = default;
  auto operator=(DeferredFinalize const&) -> DeferredFinalize& = default;
  auto operator=(DeferredFinalize&&) -> DeferredFinalize& = default;
  DeferredFinalize() = default;
  ~DeferredFinalize()
  {
#ifdef VIPRA_USE_MPI
    int initialized = 0;
    int finalized = 0;
    MPI_Initialized(&initialized);
    // MPI_Initialized stays true after MPI_Finalize, so also check MPI_Finalized.
    MPI_Finalized(&finalized);
    if ( initialized != 0 && finalized == 0 ) MPI_Finalize();
#endif
  }
};
126
127// NOLINTBEGIN
128#ifdef VIPRA_USE_MPI
129 static MPI_Comm comm;
130#endif
131
132 static Util::Timings _timings;
133 static Util::Timings _mpiTimings;
134 static Util::Timings _inputTimings;
135
136 static int rank;
137 static int size;
138 static DeferredFinalize _finalize;
139 // NOLINTEND
140
/**
 * @brief Loads @p input from the file at @p filepath.
 *
 * Forwards directly to the input object's own load() member.
 *
 * @tparam input_t any type with a load() member callable with a std::string
 * @param input    object to populate
 * @param filepath path handed unchanged to input.load()
 */
template <typename input_t>
static void load_inputs(input_t& input, std::string const& filepath) { input.load(filepath); }
152
153 static void disseminate_input(Modules::Serializable& input)
154 {
155 std::string serialized{};
156 int length{};
157 if ( rank == 0 ) {
158 serialized = input.serialize();
159 length = static_cast<int>(serialized.size());
160 }
161#ifdef VIPRA_USE_MPI
162 _mpiTimings.resume();
163 MPI_Bcast(&length, 1, MPI_INT, 0, comm);
164
165 if ( rank != 0 ) {
166 serialized.resize(length);
167 }
168
169 MPI_Bcast(serialized.data(), length, MPI_CHAR, 0, comm);
170 _mpiTimings.pause();
171
172 if ( rank != 0 ) {
173 input.parse(serialized);
174 }
175#endif
176 }
177};
178} // namespace VIPRA
static VIPRA_INLINE void info(fmt::format_string< param_ts... > message, param_ts &&... params)
Calls the provided Logger with Level INFO.
Definition logging.hpp:68
Definition parameter_sweep.hpp:20
static void run(std::string const &installPath, std::string const &modulesPath, std::string const &pedPath, std::string const &mapPath, std::string const &paramsPath, size_t count, auto &&callback=VOID{})
Runs a parameter sweep over the worker nodes.
Definition parameter_sweep.hpp:47
Definition parameters.hpp:20
Definition sim_type.hpp:29
Placeholder Type for void.
Definition result_or_void.hpp:10