Line data Source code
1 0 : // Distributed under the MIT License.
2 : // See LICENSE.txt for details.
3 :
4 : #pragma once
5 :
6 : #include <memory>
7 : #include <pup.h>
8 : #include <string>
9 : #include <unordered_map>
10 : #include <unordered_set>
11 : #include <utility>
12 :
13 : #include "DataStructures/DataBox/DataBox.hpp"
14 : #include "DataStructures/DataVector.hpp"
15 : #include "Domain/FunctionsOfTime/FunctionOfTime.hpp"
16 : #include "Parallel/GlobalCache.hpp"
17 : #include "Utilities/ErrorHandling/Assert.hpp"
18 : #include "Utilities/TMPL.hpp"
19 : #include "Utilities/TaggedTuple.hpp"
20 :
21 : /// \cond
22 : namespace domain::Tags {
23 : struct FunctionsOfTime;
24 : } // namespace domain::Tags
25 : namespace control_system::Tags {
26 : struct UpdateAggregators;
27 : struct SystemToCombinedNames;
28 : struct MeasurementTimescales;
29 : } // namespace control_system::Tags
30 : /// \endcond
31 :
32 : namespace control_system {
33 : /// \ingroup ControlSystemGroup
34 : /// Updates a FunctionOfTime in the global cache. Intended to be used in
35 : /// Parallel::mutate.
36 1 : struct UpdateSingleFunctionOfTime {
37 0 : static void apply(
38 : gsl::not_null<std::unordered_map<
39 : std::string,
40 : std::unique_ptr<domain::FunctionsOfTime::FunctionOfTime>>*>
41 : f_of_t_list,
42 : const std::string& f_of_t_name, double update_time,
43 : DataVector update_deriv, double new_expiration_time);
44 : };
45 :
46 : /*!
47 : * \ingroup ControlSystemGroup
48 : * \brief Updates several FunctionOfTimes in the global cache at once. Intended
49 : * to be used in Parallel::mutate.
50 : *
51 : * \details All functions of time are updated at the same `update_time`. For the
52 : * `update_args`, the keys of the map are the names of the functions of time.
53 : * The value `std::pair<DataVector, double>` for each key is the updated
54 : * derivative for the function of time and the new expiration time,
55 : * respectively.
56 : */
57 1 : struct UpdateMultipleFunctionsOfTime {
58 0 : static void apply(
59 : gsl::not_null<std::unordered_map<
60 : std::string,
61 : std::unique_ptr<domain::FunctionsOfTime::FunctionOfTime>>*>
62 : f_of_t_list,
63 : double update_time,
64 : const std::unordered_map<std::string, std::pair<DataVector, double>>&
65 : update_args);
66 : };
67 :
68 : /*!
69 : * \brief A class for collecting and storing information related to updating
70 : * functions of time and measurement timescales.
71 : *
72 : * \details This class also determines if enough data has been received in order
73 : * for the functions of time and measurement timescales to be updated. There
74 : * should be one `UpdateAggregator` for every group of control systems that have
75 : * the same `control_system::protocols::Measurement`.
76 : */
77 1 : struct UpdateAggregator {
78 0 : UpdateAggregator() = default;
79 :
80 : /*!
81 : * \brief Construct a new UpdateAggregator using a set of active control
82 : * system names and the combined name for the measurement.
83 : *
84 : * It is expected that all the control systems in this set use the same
85 : * `control_system::protocols::Measurement`.
86 : */
87 1 : UpdateAggregator(std::string combined_name,
88 : std::unordered_set<std::string> active_control_system_names);
89 :
90 : /*!
91 : * \brief Inserts and stores information for one of the control systems that
92 : * this class was constructed with.
93 : *
94 : * \param control_system_name Name of control system to add information for
95 : * \param new_measurement_timescale DataVector of new measurement timescales
96 : * calculated from a control system update.
97 : * \param new_measurement_expiration_time New measurement expiration time
98 : * calculated during that update.
99 : * \param control_signal New highest derivative for the function of time
100 : * calculated during that update (will be `std::move`ed).
101 : * \param new_fot_expiration_time New function of time expiration time
102 : * calculated for during that update
103 : */
104 1 : void insert(const std::string& control_system_name,
105 : const DataVector& new_measurement_timescale,
106 : double new_measurement_expiration_time, DataVector control_signal,
107 : double new_fot_expiration_time);
108 :
109 : /*!
110 : * \brief Checks if `insert` has been called for all control systems that this
111 : * class was constructed with.
112 : */
113 1 : bool is_ready() const;
114 :
115 : /*!
116 : * \brief Returns a sorted concatenation of the control system names this
117 : * class was constructed with.
118 : */
119 1 : const std::string& combined_name() const;
120 :
121 : /*!
122 : * \brief Once `is_ready` is true, returns a map between the control system
123 : * name and a `std::pair` containing the `control_signal` that was passed to
124 : * `insert` and the minimum of all the `new_fot_expiration_time`s passed to
125 : * `insert`.
126 : *
127 : * \details This function is expected to only be called when `is_ready` is
128 : * true. It also must be called before `combined_measurement_expiration_time`.
129 : */
130 : std::unordered_map<std::string, std::pair<DataVector, double>>
131 1 : combined_fot_expiration_times() const;
132 :
133 : /*!
134 : * \brief Once `is_ready` is true, returns a `std::pair` containing the
135 : * minimum of all `new_measurement_timescale`s passed to `insert` and the
136 : * minimum of all `new_measurement_expiration_time`s passed to `insert`.
137 : *
138 : * \details This function is expected to be called only when `is_ready` is
139 : * true and only a single time once all control active control
140 : * systems for this measurement have computed their update values. It also
141 : * must be called after `combined_fot_expiration_times`. This function clears
142 : * all stored data when it is called.
143 : */
144 1 : std::pair<double, double> combined_measurement_expiration_time();
145 :
146 : /// \cond
147 : void pup(PUP::er& p);
148 : /// \endcond
149 :
150 : private:
151 : std::unordered_map<std::string, std::pair<std::pair<DataVector, double>,
152 : std::pair<double, double>>>
153 0 : expiration_times_{};
154 0 : std::unordered_set<std::string> active_names_{};
155 0 : std::string combined_name_{};
156 : };
157 :
158 : /*!
159 : * \brief Simple action that updates the appropriate `UpdateAggregator` for the
160 : * templated `ControlSystem`.
161 : *
162 : * \details The `new_measurement_timescale`, `new_measurement_expiration_time`,
163 : * `control_signal`, and `new_fot_expiration_time` are passed along to the
164 : * `UpdateAggregator::insert` function. The `old_measurement_expiration_time`
165 : * and `old_fot_expiration_time` are only used when the
166 : * `UpdateAggregator::is_ready` in order to update the functions of time and the
167 : * measurement timescale.
168 : *
169 : * When the `UpdateAggregator::is_ready`, the measurement timescale is mutated
170 : * with `UpdateSingleFunctionOfTime` and the functions of time are mutated with
171 : * `UpdateMultipleFunctionsOfTime`, both using `Parallel::mutate`.
172 : *
173 : * The "appropriate" `UpdateAggregator` is chosen from the
174 : * `control_system::Tags::SystemToCombinedNames` for the templated
175 : * `ControlSystem`.
176 : */
177 : template <typename ControlSystem>
178 1 : struct AggregateUpdate {
179 : template <typename ParallelComponent, typename DbTags, typename Metavariables,
180 : typename ArrayIndex>
181 0 : static void apply(db::DataBox<DbTags>& box,
182 : Parallel::GlobalCache<Metavariables>& cache,
183 : const ArrayIndex& /*array_index*/,
184 : const DataVector& new_measurement_timescale,
185 : const double old_measurement_expiration_time,
186 : const double new_measurement_expiration_time,
187 : DataVector control_signal,
188 : const double old_fot_expiration_time,
189 : const double new_fot_expiration_time) {
190 : auto& aggregators =
191 : db::get_mutable_reference<Tags::UpdateAggregators>(make_not_null(&box));
192 : const auto& system_to_combined_names =
193 : Parallel::get<Tags::SystemToCombinedNames>(cache);
194 : const std::string& control_system_name = ControlSystem::name();
195 : ASSERT(system_to_combined_names.count(control_system_name) == 1,
196 : "Expected name '" << control_system_name
197 : << "' to be in map of system-to-combined names, "
198 : "but it wasn't. Keys are: "
199 : << keys_of(system_to_combined_names));
200 : const std::string& combined_name =
201 : system_to_combined_names.at(control_system_name);
202 : ASSERT(aggregators.count(combined_name) == 1,
203 : "Expected combined name '" << combined_name
204 : << "' to be in map of aggregators, "
205 : "but it wasn't. Keys are: "
206 : << keys_of(aggregators));
207 :
208 : UpdateAggregator& aggregator = aggregators.at(combined_name);
209 :
210 : aggregator.insert(control_system_name, new_measurement_timescale,
211 : new_measurement_expiration_time,
212 : std::move(control_signal), new_fot_expiration_time);
213 :
214 : if (aggregator.is_ready()) {
215 : std::unordered_map<std::string, std::pair<DataVector, double>>
216 : combined_fot_expiration_times =
217 : aggregator.combined_fot_expiration_times();
218 : const std::pair<double, double> combined_measurement_expiration_time =
219 : aggregator.combined_measurement_expiration_time();
220 :
221 : Parallel::mutate<Tags::MeasurementTimescales, UpdateSingleFunctionOfTime>(
222 : cache, combined_name, old_measurement_expiration_time,
223 : DataVector{1, combined_measurement_expiration_time.first},
224 : combined_measurement_expiration_time.second);
225 :
226 : Parallel::mutate<::domain::Tags::FunctionsOfTime,
227 : UpdateMultipleFunctionsOfTime>(
228 : cache, old_fot_expiration_time,
229 : std::move(combined_fot_expiration_times));
230 : }
231 : }
232 : };
233 : } // namespace control_system
|