Skip to content

Commit 2205766

Browse files
Thomas Grützmacher, pratikvn, yhmtsai
committed
Review updates
Fix GMRES reference test initialization and improve memory read efficiency of hessenberg_qr.

Co-authored-by: Pratik Nayak <pratikvn@protonmail.com>
Co-authored-by: Yuhsiang M. Tsai <yhmtsai@gmail.com>
1 parent 6164a27 commit 2205766

3 files changed

Lines changed: 30 additions & 27 deletions

File tree

common/unified/solver/common_gmres_kernels.cpp

Lines changed: 29 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -107,44 +107,48 @@ void hessenberg_qr(std::shared_ptr<const DefaultExecutor> exec,
107107
}
108108
// increment iteration count
109109
final_iter_nums[rhs]++;
110+
auto hess_this = hessenberg_iter(0, rhs);
111+
auto hess_next = hessenberg_iter(1, rhs);
110112
// apply previous Givens rotations to column
111113
for (int64 j = 0; j < iter; ++j) {
112-
auto out1 = givens_cos(j, rhs) * hessenberg_iter(j, rhs) +
113-
givens_sin(j, rhs) * hessenberg_iter(j + 1, rhs);
114-
auto out2 =
115-
-conj(givens_sin(j, rhs)) * hessenberg_iter(j, rhs) +
116-
conj(givens_cos(j, rhs)) * hessenberg_iter(j + 1, rhs);
114+
// in here: hess_this = hessenberg_iter(j, rhs);
115+
// hess_next = hessenberg_iter(j+1, rhs);
116+
hess_next = hessenberg_iter(j + 1, rhs);
117+
const auto gc = givens_cos(j, rhs);
118+
const auto gs = givens_sin(j, rhs);
119+
const auto out1 = gc * hess_this + gs * hess_next;
120+
const auto out2 = -conj(gs) * hess_this + conj(gc) * hess_next;
117121
hessenberg_iter(j, rhs) = out1;
118-
hessenberg_iter(j + 1, rhs) = out2;
122+
hessenberg_iter(j + 1, rhs) = hess_this = out2;
123+
hess_next = hessenberg_iter(j + 2, rhs);
119124
}
125+
// hess_this is hessenberg_iter(iter, rhs) and
126+
// hess_next is hessenberg_iter(iter + 1, rhs)
127+
auto gs = givens_sin(iter, rhs);
128+
auto gc = givens_cos(iter, rhs);
120129
// compute new Givens rotation
121-
if (hessenberg_iter(iter, rhs) == zero<value_type>()) {
122-
givens_cos(iter, rhs) = zero<value_type>();
123-
givens_sin(iter, rhs) = one<value_type>();
130+
if (hess_this == zero<value_type>()) {
131+
givens_cos(iter, rhs) = gc = zero<value_type>();
132+
givens_sin(iter, rhs) = gs = one<value_type>();
124133
} else {
125-
const auto this_hess = hessenberg_iter(iter, rhs);
126-
const auto next_hess = hessenberg_iter(iter + 1, rhs);
127-
const auto scale = abs(this_hess) + abs(next_hess);
134+
const auto scale = abs(hess_this) + abs(hess_next);
128135
const auto hypotenuse =
129136
scale *
130-
sqrt(abs(this_hess / scale) * abs(this_hess / scale) +
131-
abs(next_hess / scale) * abs(next_hess / scale));
132-
givens_cos(iter, rhs) = conj(this_hess) / hypotenuse;
133-
givens_sin(iter, rhs) = conj(next_hess) / hypotenuse;
137+
sqrt(abs(hess_this / scale) * abs(hess_this / scale) +
138+
abs(hess_next / scale) * abs(hess_next / scale));
139+
givens_cos(iter, rhs) = gc = conj(hess_this) / hypotenuse;
140+
givens_sin(iter, rhs) = gs = conj(hess_next) / hypotenuse;
134141
}
135142
// apply new Givens rotation to column
136-
hessenberg_iter(iter, rhs) =
137-
givens_cos(iter, rhs) * hessenberg_iter(iter, rhs) +
138-
givens_sin(iter, rhs) * hessenberg_iter(iter + 1, rhs);
143+
hessenberg_iter(iter, rhs) = gc * hess_this + gs * hess_next;
139144
hessenberg_iter(iter + 1, rhs) = zero<value_type>();
140145
// apply new Givens rotation to RHS of least-squares problem
141-
residual_norm_collection(iter + 1, rhs) =
142-
-conj(givens_sin(iter, rhs)) *
143-
residual_norm_collection(iter, rhs);
146+
const auto rnc_new =
147+
-conj(gs) * residual_norm_collection(iter, rhs);
148+
residual_norm_collection(iter + 1, rhs) = rnc_new;
144149
residual_norm_collection(iter, rhs) =
145-
givens_cos(iter, rhs) * residual_norm_collection(iter, rhs);
146-
residual_norm(0, rhs) =
147-
abs(residual_norm_collection(iter + 1, rhs));
150+
gc * residual_norm_collection(iter, rhs);
151+
residual_norm(0, rhs) = abs(rnc_new);
148152
},
149153
hessenberg_iter->get_size()[1], givens_sin, givens_cos, residual_norm,
150154
residual_norm_collection, hessenberg_iter, iter, final_iter_nums,

core/solver/gmres.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,6 @@ void Gmres<ValueType>::apply_dense_impl(const matrix::Dense<ValueType>* dense_b,
159159
// rows: rows of Hessenberg matrix, columns: block for each entry
160160
auto hessenberg = this->template create_workspace_op<Vector>(
161161
ws::hessenberg, dim<2>{krylov_dim + 1, krylov_dim * num_rhs});
162-
hessenberg->fill(0);
163162
auto givens_sin = this->template create_workspace_op<Vector>(
164163
ws::givens_sin, dim<2>{krylov_dim, num_rhs});
165164
auto givens_cos = this->template create_workspace_op<Vector>(

reference/test/solver/gmres_kernels.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ TYPED_TEST(Gmres, KernelSolveKrylov)
305305
// clang-format off
306306
{{-1, 3, 2, -4},
307307
{0, 0, 1, 5},
308-
{nan, nan, nan}},
308+
{nan, nan, nan, nan}},
309309
// clang-format on
310310
this->exec);
311311
this->small_residual_norm_collection =

0 commit comments

Comments
 (0)