@@ -76,6 +76,7 @@ std::unique_ptr<LinOp> Gmres<ValueType>::transpose() const
7676 share (as<Transposable>(this ->get_preconditioner ())->transpose ()))
7777 .with_criteria (this ->get_stop_criterion_factory ())
7878 .with_krylov_dim (this ->get_krylov_dim ())
79+ .with_flexible (this ->get_parameters ().flexible )
7980 .on (this ->get_executor ())
8081 ->generate (
8182 share (as<Transposable>(this ->get_system_matrix ())->transpose ()));
@@ -90,6 +91,7 @@ std::unique_ptr<LinOp> Gmres<ValueType>::conj_transpose() const
9091 as<Transposable>(this ->get_preconditioner ())->conj_transpose ()))
9192 .with_criteria (this ->get_stop_criterion_factory ())
9293 .with_krylov_dim (this ->get_krylov_dim ())
94+ .with_flexible (this ->get_parameters ().flexible )
9395 .on (this ->get_executor ())
9496 ->generate (share (
9597 as<Transposable>(this ->get_system_matrix ())->conj_transpose ()));
@@ -196,7 +198,7 @@ void Gmres<ValueType>::apply_dense_impl(const VectorType* dense_b,
196198
197199 auto exec = this ->get_executor ();
198200 this ->setup_workspace ();
199-
201+ const auto is_flexible = this -> get_parameters (). flexible ;
200202 const auto num_rows = this ->get_size ()[0 ];
201203 const auto local_num_rows =
202204 ::gko::detail::get_local (dense_b)->get_size()[0];
@@ -207,6 +209,10 @@ void Gmres<ValueType>::apply_dense_impl(const VectorType* dense_b,
207209 auto krylov_bases = this ->create_workspace_op_with_type_of (
208210 ws::krylov_bases, dense_b, dim<2 >{num_rows * (krylov_dim + 1 ), num_rhs},
209211 dim<2 >{local_num_rows * (krylov_dim + 1 ), num_rhs});
212+ auto preconditioned_krylov_bases = this ->create_workspace_op_with_type_of (
213+ ws::preconditioned_krylov_bases, dense_b,
214+ dim<2 >{num_rows * (krylov_dim + 1 ), num_rhs},
215+ dim<2 >{local_num_rows * (krylov_dim + 1 ), num_rhs});
210216 // rows: rows of Hessenberg matrix, columns: block for each entry
211217 auto hessenberg = this ->template create_workspace_op <LocalVector>(
212218 ws::hessenberg, dim<2 >{krylov_dim + 1 , krylov_dim * num_rhs});
@@ -341,9 +347,16 @@ void Gmres<ValueType>::apply_dense_impl(const VectorType* dense_b,
341347 span{local_num_rows * (restart_iter + 1 ),
342348 local_num_rows * (restart_iter + 2 )},
343349 span{0 , num_rhs});
344- // preconditioned_vector = get_preconditioner() * this_krylov
350+ auto preconditioned_krylov = create_submatrix_helper (
351+ preconditioned_krylov_bases, dim<2 >{num_rows, num_rhs},
352+ span{local_num_rows * restart_iter,
353+ local_num_rows * (restart_iter + 1 )},
354+ span{0 , num_rhs});
355+ auto preconditioned_krylov_vector =
356+ is_flexible ? preconditioned_krylov.get () : preconditioned_vector;
357+ // preconditioned_krylov_vector = get_preconditioner() * this_krylov
345358 this ->get_preconditioner ()->apply (this_krylov.get (),
346- preconditioned_vector );
359+ preconditioned_krylov_vector );
347360
348361 // Create view of current column in the hessenberg matrix:
349362 // hessenberg_iter = hessenberg(:, restart_iter);
@@ -352,8 +365,8 @@ void Gmres<ValueType>::apply_dense_impl(const VectorType* dense_b,
352365 span{num_rhs * restart_iter, num_rhs * (restart_iter + 1 )});
353366
354367 // Start of Arnoldi
355- // next_krylov = A * preconditioned_vector
356- this ->get_system_matrix ()->apply (preconditioned_vector ,
368+ // next_krylov = A * preconditioned_krylov_vector
369+ this ->get_system_matrix ()->apply (preconditioned_krylov_vector ,
357370 next_krylov.get ());
358371
359372 for (size_type i = 0 ; i <= restart_iter; i++) {
@@ -414,6 +427,9 @@ void Gmres<ValueType>::apply_dense_impl(const VectorType* dense_b,
414427 auto krylov_bases_small = create_submatrix_helper (
415428 krylov_bases, dim<2 >{num_rows, num_rhs},
416429 span{0 , local_num_rows * (restart_iter + 1 )}, span{0 , num_rhs});
430+ auto preconditioned_krylov_bases_small = create_submatrix_helper (
431+ preconditioned_krylov_bases, dim<2 >{num_rows, num_rhs},
432+ span{0 , local_num_rows * (restart_iter + 1 )}, span{0 , num_rhs});
417433 auto hessenberg_small = hessenberg->create_submatrix (
418434 span{0 , restart_iter}, span{0 , num_rhs * (restart_iter)});
419435
@@ -422,15 +438,24 @@ void Gmres<ValueType>::apply_dense_impl(const VectorType* dense_b,
422438 exec->run (gmres::make_solve_krylov (
423439 residual_norm_collection, hessenberg_small.get (), y,
424440 final_iter_nums.get_const_data (), stop_status.get_const_data ()));
425- // before_preconditioner = krylov_bases * y
426- exec->run (gmres::make_multi_axpy (
427- gko::detail::get_local (krylov_bases_small.get ()), y,
428- gko::detail::get_local (before_preconditioner),
429- final_iter_nums.get_const_data (), stop_status.get_data ()));
430-
431- // x = x + get_preconditioner() * before_preconditioner
432- this ->get_preconditioner ()->apply (before_preconditioner,
433- after_preconditioner);
441+ if (is_flexible) {
442+ // after_preconditioner = preconditioned_krylov_bases * y
443+ exec->run (gmres::make_multi_axpy (
444+ gko::detail::get_local (preconditioned_krylov_bases_small.get ()), y,
445+ gko::detail::get_local (after_preconditioner),
446+ final_iter_nums.get_const_data (), stop_status.get_data ()));
447+ } else {
448+ // before_preconditioner = krylov_bases * y
449+ exec->run (gmres::make_multi_axpy (
450+ gko::detail::get_local (krylov_bases_small.get ()), y,
451+ gko::detail::get_local (before_preconditioner),
452+ final_iter_nums.get_const_data (), stop_status.get_data ()));
453+
454+ // after_preconditioner = get_preconditioner() * before_preconditioner
455+ this ->get_preconditioner ()->apply (before_preconditioner,
456+ after_preconditioner);
457+ }
458+ // x = x + after_preconditioner
434459 dense_x->add_scaled (one_op, after_preconditioner);
435460}
436461
@@ -463,7 +488,7 @@ int workspace_traits<Gmres<ValueType>>::num_arrays(const Solver&)
// workspace_traits hook for Gmres: returns a fixed count (the Solver argument
// is unnamed and unused here).
// This hunk raises the returned value from 14 to 15.
// NOTE(review): presumably the extra slot corresponds to the
// "preconditioned_krylov_bases" name appended to op_names() in the same
// change -- confirm the count against the full name list, which is only
// partially visible in this diff.
463488template <typename ValueType>
464489int workspace_traits<Gmres<ValueType>>::num_vectors(const Solver&)
465490{
466- return 14 ;
491+ return 15 ;
467492}
468493
469494
@@ -484,7 +509,8 @@ std::vector<std::string> workspace_traits<Gmres<ValueType>>::op_names(
484509 " after_preconditioner" ,
485510 " one" ,
486511 " minus_one" ,
487- " next_krylov_norm_tmp" };
512+ " next_krylov_norm_tmp" ,
513+ " preconditioned_krylov_bases" };
488514}
489515
490516
@@ -509,8 +535,12 @@ std::vector<int> workspace_traits<Gmres<ValueType>>::scalars(const Solver&)
// workspace_traits hook for Gmres: returns a fixed list of workspace ids as a
// std::vector<int> (the Solver argument is unnamed and unused here).
// This hunk rewrites the list one entry per line and appends
// preconditioned_krylov_bases, the workspace that apply_dense_impl creates via
// create_workspace_op_with_type_of(ws::preconditioned_krylov_bases, ...).
// NOTE(review): the entry is listed regardless of the `flexible` setting --
// confirm that is intended for non-flexible solves.
509535template <typename ValueType>
510536std::vector<int > workspace_traits<Gmres<ValueType>>::vectors(const Solver&)
511537{
512- return {residual, preconditioned_vector, krylov_bases,
513- before_preconditioner, after_preconditioner};
538+ return {residual,
539+ preconditioned_vector,
540+ krylov_bases,
541+ before_preconditioner,
542+ after_preconditioner,
543+ preconditioned_krylov_bases};
514544}
515545
516546
0 commit comments