Line data Source code
1 0 : // Distributed under the MIT License. 2 : // See LICENSE.txt for details. 3 : 4 : #pragma once 5 : 6 : #include <array> 7 : #include <cstddef> 8 : #include <functional> 9 : #include <optional> 10 : #include <unordered_map> 11 : #include <unordered_set> 12 : #include <vector> 13 : 14 : #include "Domain/Block.hpp" 15 : #include "Domain/ElementDistribution.hpp" 16 : #include "Domain/Structure/ElementId.hpp" 17 : #include "Parallel/DomainDiagnosticInfo.hpp" 18 : #include "Parallel/GlobalCache.hpp" 19 : #include "Utilities/Numeric.hpp" 20 : 21 : /// \cond 22 : namespace Spectral { 23 : enum class Basis : uint8_t; 24 : enum class Quadrature : uint8_t; 25 : } // namespace Spectral 26 : /// \endcond 27 : 28 : namespace Parallel { 29 : /*! 30 : * \brief Creates elements using a chosen distribution. 31 : * 32 : * The `func` is called with `(element_id, target_proc, target_node)` allowing 33 : * the `func` to insert the element with `element_id` on the target processor 34 : * and node. 35 : */ 36 : template <typename F, size_t Dim, typename Metavariables> 37 1 : void create_elements_using_distribution( 38 : const F& func, const std::optional<domain::ElementWeight>& element_weight, 39 : const std::vector<Block<Dim>>& blocks, 40 : const std::vector<std::array<size_t, Dim>>& initial_extents, 41 : const std::vector<std::array<size_t, Dim>>& initial_refinement_levels, 42 : const Spectral::Basis i1_basis, const Spectral::Quadrature i1_quadrature, 43 : 44 : const std::unordered_set<size_t>& procs_to_ignore, 45 : const size_t number_of_procs, const size_t number_of_nodes, 46 : const size_t num_of_procs_to_use, 47 : const Parallel::GlobalCache<Metavariables>& local_cache, 48 : const bool print_diagnostics) { 49 : // Only need the element distribution if the element weight has a value 50 : // because then we have to use the space filling curve and not just use round 51 : // robin. 52 : domain::BlockZCurveProcDistribution<Dim> element_distribution{}; 53 : if (element_weight.has_value()) { 54 : const std::unordered_map<ElementId<Dim>, double> element_costs = 55 : domain::get_element_costs(blocks, initial_refinement_levels, 56 : initial_extents, element_weight.value(), 57 : i1_basis, i1_quadrature); 58 : element_distribution = domain::BlockZCurveProcDistribution<Dim>{ 59 : element_costs, num_of_procs_to_use, blocks, initial_refinement_levels, 60 : initial_extents, procs_to_ignore}; 61 : } 62 : 63 : // Will be used to print domain diagnostic info 64 : std::vector<size_t> elements_per_core(number_of_procs, 0_st); 65 : std::vector<size_t> elements_per_node(number_of_nodes, 0_st); 66 : std::vector<size_t> grid_points_per_core(number_of_procs, 0_st); 67 : std::vector<size_t> grid_points_per_node(number_of_nodes, 0_st); 68 : 69 : size_t which_proc = 0; 70 : for (const auto& block : blocks) { 71 : const auto& initial_ref_levs = initial_refinement_levels[block.id()]; 72 : const size_t grid_points_per_element = 73 : alg::accumulate(initial_extents[block.id()], 1_st, std::multiplies<>()); 74 : 75 : const std::vector<ElementId<Dim>> element_ids = 76 : initial_element_ids(block.id(), initial_ref_levs); 77 : 78 : // Value means ZCurve. nullopt means round robin 79 : if (element_weight.has_value()) { 80 : for (const auto& element_id : element_ids) { 81 : const size_t target_proc = 82 : element_distribution.get_proc_for_element(element_id); 83 : const size_t target_node = 84 : Parallel::node_of<size_t>(target_proc, local_cache); 85 : func(element_id, target_proc, target_node); 86 : 87 : ++elements_per_core[target_proc]; 88 : ++elements_per_node[target_node]; 89 : grid_points_per_core[target_proc] += grid_points_per_element; 90 : grid_points_per_node[target_node] += grid_points_per_element; 91 : } 92 : } else { 93 : for (size_t i = 0; i < element_ids.size(); ++i) { 94 : while (procs_to_ignore.find(which_proc) != procs_to_ignore.end()) { 95 : which_proc = which_proc + 1 == number_of_procs ? 0 : which_proc + 1; 96 : } 97 : const size_t target_proc = which_proc; 98 : const size_t target_node = 99 : Parallel::node_of<size_t>(which_proc, local_cache); 100 : const ElementId<Dim> element_id(element_ids[i]); 101 : func(element_id, target_proc, target_node); 102 : 103 : ++elements_per_core[which_proc]; 104 : ++elements_per_node[target_node]; 105 : grid_points_per_core[which_proc] += grid_points_per_element; 106 : grid_points_per_node[target_node] += grid_points_per_element; 107 : 108 : which_proc = which_proc + 1 == number_of_procs ? 0 : which_proc + 1; 109 : } 110 : } 111 : } 112 : 113 : if (print_diagnostics) { 114 : Parallel::printf("\n%s\n", domain::diagnostic_info( 115 : blocks.size(), local_cache, 116 : elements_per_core, elements_per_node, 117 : grid_points_per_core, grid_points_per_node)); 118 : } 119 : } 120 : } // namespace Parallel