@@ -52,7 +52,7 @@ namespace multigrid {
5252GKO_REGISTER_OPERATION (initialize_v, multigrid::initialize_v);
5353
5454
55- }
55+ } // namespace multigrid
5656
5757
5858template <typename ValueType>
@@ -217,9 +217,117 @@ void Multigrid<ValueType>::prepare_vcycle(
217217}
218218
219219
220+ template <typename ValueType>
221+ void Multigrid<ValueType>::run_cycle(
222+ multigrid_cycle cycle, size_type level, std::shared_ptr<const LinOp> matrix,
223+ const matrix::Dense<ValueType> *b, matrix::Dense<ValueType> *x,
224+ std::vector<std::shared_ptr<matrix::Dense<ValueType>>> &r_list,
225+ std::vector<std::shared_ptr<matrix::Dense<ValueType>>> &g_list,
226+ std::vector<std::shared_ptr<matrix::Dense<ValueType>>> &e_list) const
227+ {
228+ auto r = r_list.at (level);
229+ auto g = g_list.at (level);
230+ auto e = e_list.at (level);
231+ r->copy_from (b);
232+ matrix->apply (neg_one_op_.get (), x, one_op_.get (), r.get ());
233+ // x += relaxation * Smoother(r)
234+ auto pre_smoother = pre_smoother_list_.at (level);
235+ std::shared_ptr<matrix::Dense<ValueType>> pre_relaxation;
236+ if (parameters_.pre_relaxation .get_num_elems () == 0 ) {
237+ pre_relaxation = one_op_;
238+ } else {
239+ pre_relaxation = pre_relaxation_list_.at (level);
240+ }
241+ if (pre_smoother) {
242+ pre_smoother->apply (pre_relaxation.get (), r.get (), one_op_.get (), x);
243+ // compute residual
244+ r->copy_from (b); // n * b
245+ matrix->apply (neg_one_op_.get (), x, one_op_.get (), r.get ());
246+ }
247+ // first cycle
248+ rstr_prlg_list_.at (level)->restrict_apply (r.get (), g.get ());
249+ // next level or solve it
250+ if (level + 1 == rstr_prlg_list_.size ()) {
251+ coarsest_solver_->apply (g.get (), e.get ());
252+ } else {
253+ this ->run_cycle (cycle, level + 1 ,
254+ rstr_prlg_list_.at (level)->get_coarse_operator (),
255+ g.get (), e.get (), r_list, g_list, e_list);
256+ }
257+ // additional work for non-v_cycle
258+ if (cycle == multigrid_cycle::f || cycle == multigrid_cycle::w) {
259+ // second cycle - f_cycle, w_cycle
260+ // prolong
261+ rstr_prlg_list_.at (level)->prolong_applyadd (e.get (), x);
262+ // compute residual
263+ r->copy_from (b); // n * b
264+ matrix->apply (neg_one_op_.get (), x, one_op_.get (), r.get ());
265+ // re-smooth
266+ if (pre_smoother) {
267+ pre_smoother->apply (pre_relaxation.get (), r.get (), one_op_.get (),
268+ x);
269+ // compute residual
270+ r->copy_from (b); // n * b
271+ matrix->apply (neg_one_op_.get (), x, one_op_.get (), r.get ());
272+ }
273+
274+ rstr_prlg_list_.at (level)->restrict_apply (r.get (), g.get ());
275+ // next level or solve it
276+ if (level + 1 == rstr_prlg_list_.size ()) {
277+ coarsest_solver_->apply (g.get (), e.get ());
278+ } else {
279+ if (cycle == multigrid_cycle::f) {
280+ // f_cycle call v_cycle in the second cycle
281+ this ->run_cycle (
282+ multigrid_cycle::v, level + 1 ,
283+ rstr_prlg_list_.at (level)->get_coarse_operator (), g.get (),
284+ e.get (), r_list, g_list, e_list);
285+ } else {
286+ this ->run_cycle (
287+ cycle, level + 1 ,
288+ rstr_prlg_list_.at (level)->get_coarse_operator (), g.get (),
289+ e.get (), r_list, g_list, e_list);
290+ }
291+ }
292+ } else if (cycle == multigrid_cycle::kfcg ||
293+ cycle == multigrid_cycle::kgcr) {
294+ // do some work in coarse level - do not need prolong
295+ GKO_NOT_IMPLEMENTED ;
296+ }
297+
298+ // prolong
299+ rstr_prlg_list_.at (level)->prolong_applyadd (e.get (), x);
300+
301+ // post-smooth
302+ std::shared_ptr<LinOp> post_smoother;
303+ std::shared_ptr<matrix::Dense<ValueType>> post_relaxation;
304+
305+ if (parameters_.post_uses_pre ) {
306+ post_smoother = pre_smoother;
307+ post_relaxation = pre_relaxation;
308+ } else {
309+ post_smoother = post_smoother_list_.at (level);
310+ if (parameters_.post_relaxation .get_num_elems () == 0 ) {
311+ post_relaxation = one_op_;
312+ } else {
313+ post_relaxation = post_relaxation_list_.at (level);
314+ }
315+ }
316+ if (post_smoother) {
317+ r->copy_from (b);
318+ matrix->apply (neg_one_op_.get (), x, one_op_.get (), r.get ());
319+ post_smoother->apply (post_relaxation.get (), r.get (), one_op_.get (), x);
320+ }
321+ }
322+
323+
220324template <typename ValueType>
221325void Multigrid<ValueType>::apply_impl(const LinOp *b, LinOp *x) const
222326{
327+ if (cycle_ == multigrid_cycle::kfcg || cycle_ == multigrid_cycle::kgcr) {
328+ GKO_NOT_IMPLEMENTED ;
329+ }
330+
223331 auto exec = this ->get_executor ();
224332 constexpr uint8 RelativeStoppingId{1 };
225333 Array<stopping_status> stop_status (exec, b->get_size ()[1 ]);
@@ -231,33 +339,27 @@ void Multigrid<ValueType>::apply_impl(const LinOp *b, LinOp *x) const
231339 auto stop_criterion = stop_criterion_factory_->generate (
232340 system_matrix_, std::shared_ptr<const LinOp>(b, [](const LinOp *) {}),
233341 x, r.get ());
234- if (1 ) {
235- std::vector<std::shared_ptr<vector_type>> r_list (
236- rstr_prlg_list_.size ());
237- std::vector<std::shared_ptr<vector_type>> g_list (
238- rstr_prlg_list_.size ());
239- std::vector<std::shared_ptr<vector_type>> e_list (
240- rstr_prlg_list_.size ());
241- this ->prepare_vcycle (b->get_size ()[1 ], r_list, g_list, e_list);
242- exec->run (multigrid::make_initialize_v (e_list, &stop_status));
243- int iter = -1 ;
244- while (true ) {
245- ++iter;
246- this ->template log <log::Logger::iteration_complete>(
247- this , iter, r.get (), dense_x);
248- if (stop_criterion->update ()
249- .num_iterations (iter)
250- .residual (r.get ())
251- .solution (dense_x)
252- .check (RelativeStoppingId, true , &stop_status,
253- &one_changed)) {
254- break ;
255- }
256- this ->v_cycle (0 , system_matrix_, dense_b, dense_x, r_list, g_list,
257- e_list);
258- r->copy_from (dense_b);
259- system_matrix_->apply (neg_one_op_.get (), x, one_op_.get (), r.get ());
342+ std::vector<std::shared_ptr<vector_type>> r_list (rstr_prlg_list_.size ());
343+ std::vector<std::shared_ptr<vector_type>> g_list (rstr_prlg_list_.size ());
344+ std::vector<std::shared_ptr<vector_type>> e_list (rstr_prlg_list_.size ());
345+ this ->prepare_vcycle (b->get_size ()[1 ], r_list, g_list, e_list);
346+ exec->run (multigrid::make_initialize_v (e_list, &stop_status));
347+ int iter = -1 ;
348+ while (true ) {
349+ ++iter;
350+ this ->template log <log::Logger::iteration_complete>(this , iter, r.get (),
351+ dense_x);
352+ if (stop_criterion->update ()
353+ .num_iterations (iter)
354+ .residual (r.get ())
355+ .solution (dense_x)
356+ .check (RelativeStoppingId, true , &stop_status, &one_changed)) {
357+ break ;
260358 }
359+ this ->run_cycle (cycle_, 0 , system_matrix_, dense_b, dense_x, r_list,
360+ g_list, e_list);
361+ r->copy_from (dense_b);
362+ system_matrix_->apply (neg_one_op_.get (), x, one_op_.get (), r.get ());
261363 }
262364}
263365
0 commit comments