Line data Source code
1 0 : // Distributed under the MIT License.
2 : // See LICENSE.txt for details.
3 :
4 : #pragma once
5 :
6 : #include <memory>
7 : #include <pup.h>
8 : #include <string>
9 : #include <unordered_map>
10 : #include <unordered_set>
11 : #include <utility>
12 :
13 : #include "ControlSystem/Tags/IsActiveMap.hpp"
14 : #include "DataStructures/DataBox/DataBox.hpp"
15 : #include "DataStructures/DataVector.hpp"
16 : #include "Domain/FunctionsOfTime/FunctionOfTime.hpp"
17 : #include "Parallel/GlobalCache.hpp"
18 : #include "Utilities/ErrorHandling/Assert.hpp"
19 : #include "Utilities/TMPL.hpp"
20 : #include "Utilities/TaggedTuple.hpp"
21 :
22 : /// \cond
23 : namespace domain::Tags {
24 : struct FunctionsOfTime;
25 : } // namespace domain::Tags
26 : namespace control_system::Tags {
27 : struct UpdateAggregators;
28 : struct SystemToCombinedNames;
29 : struct MeasurementTimescales;
30 : } // namespace control_system::Tags
31 : /// \endcond
32 :
33 : namespace control_system {
34 : /// \ingroup ControlSystemGroup
35 : /// Updates a FunctionOfTime in the global cache. Intended to be used in
36 : /// Parallel::mutate.
37 1 : struct UpdateSingleFunctionOfTime {
38 0 : static void apply(
39 : gsl::not_null<std::unordered_map<
40 : std::string,
41 : std::unique_ptr<domain::FunctionsOfTime::FunctionOfTime>>*>
42 : f_of_t_list,
43 : const std::string& f_of_t_name, double update_time,
44 : DataVector update_deriv, double new_expiration_time);
45 : };
46 :
47 : /*!
48 : * \ingroup ControlSystemGroup
49 : * \brief Updates several FunctionOfTimes in the global cache at once. Intended
50 : * to be used in Parallel::mutate.
51 : *
52 : * \details All functions of time are updated at the same `update_time`. For the
53 : * `update_args`, the keys of the map are the names of the functions of time.
54 : * The value `std::pair<DataVector, double>` for each key is the updated
55 : * derivative for the function of time and the new expiration time,
56 : * respectively.
57 : */
58 1 : struct UpdateMultipleFunctionsOfTime {
59 0 : static void apply(
60 : gsl::not_null<std::unordered_map<
61 : std::string,
62 : std::unique_ptr<domain::FunctionsOfTime::FunctionOfTime>>*>
63 : f_of_t_list,
64 : double update_time,
65 : const std::unordered_map<std::string, std::pair<DataVector, double>>&
66 : update_args);
67 : };
68 :
69 : /*!
70 : * \brief A class for collecting and storing information related to updating
71 : * functions of time and measurement timescales.
72 : *
73 : * \details This class also determines if enough data has been received in order
74 : * for the functions of time and measurement timescales to be updated. There
75 : * should be one `UpdateAggregator` for every group of control systems that have
76 : * the same `control_system::protocols::Measurement`.
77 : */
78 1 : struct UpdateAggregator {
79 0 : UpdateAggregator() = default;
80 :
81 : /*!
82 : * \brief Construct a new UpdateAggregator using a set of active control
83 : * system names and the combined name for the measurement.
84 : *
85 : * It is expected that all the control systems in this set use the same
86 : * `control_system::protocols::Measurement`.
87 : */
88 1 : UpdateAggregator(std::string combined_name,
89 : std::unordered_set<std::string> active_control_system_names);
90 :
91 : /*!
92 : * \brief Inserts and stores information for one of the control systems that
93 : * this class was constructed with.
94 : *
95 : * \param control_system_name Name of control system to add information for
96 : * \param new_measurement_timescale DataVector of new measurement timescales
97 : * calculated from a control system update.
98 : * \param new_measurement_expiration_time New measurement expiration time
99 : * calculated during that update.
100 : * \param control_signal New highest derivative for the function of time
101 : * calculated during that update (will be `std::move`ed).
102 : * \param new_fot_expiration_time New function of time expiration time
103 : * calculated for during that update
104 : */
105 1 : void insert(const std::string& control_system_name,
106 : const DataVector& new_measurement_timescale,
107 : double new_measurement_expiration_time, DataVector control_signal,
108 : double new_fot_expiration_time);
109 :
110 : /*!
111 : * \brief Checks if `insert` has been called for all control systems that this
112 : * class was constructed with.
113 : *
114 : * The `active_map` is used to inform the aggregator which controls systems
115 : * are currently active. This is separate so that individual control
116 : * systems can be enabled and disabled during an evolution.
117 : */
118 1 : bool is_ready(const std::unordered_map<std::string, bool>& active_map) const;
119 :
120 : /*!
121 : * \brief Returns a sorted concatenation of the control system names this
122 : * class was constructed with.
123 : */
124 1 : const std::string& combined_name() const;
125 :
126 : /*!
127 : * \brief Once `is_ready` is true, returns a map between the control system
128 : * name and a `std::pair` containing the `control_signal` that was passed to
129 : * `insert` and the minimum of all the `new_fot_expiration_time`s passed to
130 : * `insert`.
131 : *
132 : * \details This function is expected to only be called when `is_ready` is
133 : * true. It also must be called before `combined_measurement_expiration_time`.
134 : *
135 : * The `active_map` is used to inform the aggregator which controls systems
136 : * are currently active. This is separate so that individual controls
137 : * systems can be enabled and disabled during an evolution.
138 : */
139 : std::unordered_map<std::string, std::pair<DataVector, double>>
140 1 : combined_fot_expiration_times(
141 : const std::unordered_map<std::string, bool>& active_map) const;
142 :
143 : /*!
144 : * \brief Once `is_ready` is true, returns a `std::pair` containing the
145 : * minimum of all `new_measurement_timescale`s passed to `insert` and the
146 : * minimum of all `new_measurement_expiration_time`s passed to `insert`.
147 : *
148 : * \details This function is expected to be called only when `is_ready` is
149 : * true and only a single time once all control active control
150 : * systems for this measurement have computed their update values. It also
151 : * must be called after `combined_fot_expiration_times`. This function clears
152 : * all stored data when it is called.
153 : *
154 : * The `active_map` is used to inform the aggregator which controls systems
155 : * are currently active. This is separate so that individual controls
156 : * systems can be enabled and disabled during an evolution.
157 : */
158 1 : std::pair<double, double> combined_measurement_expiration_time(
159 : const std::unordered_map<std::string, bool>& active_map);
160 :
161 : /// \cond
162 : void pup(PUP::er& p);
163 : /// \endcond
164 :
165 : private:
166 : std::unordered_map<std::string, std::pair<std::pair<DataVector, double>,
167 : std::pair<double, double>>>
168 0 : expiration_times_{};
169 0 : std::unordered_set<std::string> active_names_{};
170 0 : std::string combined_name_{};
171 : };
172 :
173 : /*!
174 : * \brief Simple action that updates the appropriate `UpdateAggregator` for the
175 : * templated `ControlSystem`.
176 : *
177 : * \details The `new_measurement_timescale`, `new_measurement_expiration_time`,
178 : * `control_signal`, and `new_fot_expiration_time` are passed along to the
179 : * `UpdateAggregator::insert` function. The `old_measurement_expiration_time`
180 : * and `old_fot_expiration_time` are only used when the
181 : * `UpdateAggregator::is_ready` in order to update the functions of time and the
182 : * measurement timescale.
183 : *
184 : * When the `UpdateAggregator::is_ready`, the measurement timescale is mutated
185 : * with `UpdateSingleFunctionOfTime` and the functions of time are mutated with
186 : * `UpdateMultipleFunctionsOfTime`, both using `Parallel::mutate`.
187 : *
188 : * The "appropriate" `UpdateAggregator` is chosen from the
189 : * `control_system::Tags::SystemToCombinedNames` for the templated
190 : * `ControlSystem`.
191 : */
192 : template <typename ControlSystem>
193 1 : struct AggregateUpdate {
194 : template <typename ParallelComponent, typename DbTags, typename Metavariables,
195 : typename ArrayIndex>
196 0 : static void apply(db::DataBox<DbTags>& box,
197 : Parallel::GlobalCache<Metavariables>& cache,
198 : const ArrayIndex& /*array_index*/,
199 : const DataVector& new_measurement_timescale,
200 : const double old_measurement_expiration_time,
201 : const double new_measurement_expiration_time,
202 : DataVector control_signal,
203 : const double old_fot_expiration_time,
204 : const double new_fot_expiration_time) {
205 : auto& aggregators =
206 : db::get_mutable_reference<Tags::UpdateAggregators>(make_not_null(&box));
207 : const auto& system_to_combined_names =
208 : Parallel::get<Tags::SystemToCombinedNames>(cache);
209 : const std::string& control_system_name = ControlSystem::name();
210 : ASSERT(system_to_combined_names.count(control_system_name) == 1,
211 : "Expected name '" << control_system_name
212 : << "' to be in map of system-to-combined names, "
213 : "but it wasn't. Keys are: "
214 : << keys_of(system_to_combined_names));
215 : const std::string& combined_name =
216 : system_to_combined_names.at(control_system_name);
217 : ASSERT(aggregators.count(combined_name) == 1,
218 : "Expected combined name '" << combined_name
219 : << "' to be in map of aggregators, "
220 : "but it wasn't. Keys are: "
221 : << keys_of(aggregators));
222 :
223 : UpdateAggregator& aggregator = aggregators.at(combined_name);
224 :
225 : aggregator.insert(control_system_name, new_measurement_timescale,
226 : new_measurement_expiration_time,
227 : std::move(control_signal), new_fot_expiration_time);
228 :
229 : if (aggregator.is_ready(get<Tags::IsActiveMap>(cache))) {
230 : std::unordered_map<std::string, std::pair<DataVector, double>>
231 : combined_fot_expiration_times =
232 : aggregator.combined_fot_expiration_times(
233 : get<Tags::IsActiveMap>(cache));
234 : const std::pair<double, double> combined_measurement_expiration_time =
235 : aggregator.combined_measurement_expiration_time(
236 : get<Tags::IsActiveMap>(cache));
237 :
238 : Parallel::mutate<Tags::MeasurementTimescales, UpdateSingleFunctionOfTime>(
239 : cache, combined_name, old_measurement_expiration_time,
240 : DataVector{1, combined_measurement_expiration_time.first},
241 : combined_measurement_expiration_time.second);
242 :
243 : Parallel::mutate<::domain::Tags::FunctionsOfTime,
244 : UpdateMultipleFunctionsOfTime>(
245 : cache, old_fot_expiration_time,
246 : std::move(combined_fot_expiration_times));
247 : }
248 : }
249 : };
250 : } // namespace control_system
|