From 4951f117f3929440a2d0a24558b332f0ac53baeb Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Fri, 21 Jun 2024 10:36:35 -0700 Subject: [PATCH 001/103] Fix warning in spin taste and minor cleanup --- include/kernels/spin_taste.cuh | 16 +++++++--------- lib/spin_taste.cu | 16 +++++----------- 2 files changed, 12 insertions(+), 20 deletions(-) diff --git a/include/kernels/spin_taste.cuh b/include/kernels/spin_taste.cuh index bcd2460e2f..1c8ef0dbf6 100644 --- a/include/kernels/spin_taste.cuh +++ b/include/kernels/spin_taste.cuh @@ -19,16 +19,14 @@ namespace quda F out; /** output vector field */ const F in; /** input vector field */ - SpinTasteArg(ColorSpinorField &out_, const ColorSpinorField &in_) : - kernel_param(dim3(in_.VolumeCB(), in_.SiteSubset(), 1)), out(out_), in(in_) + SpinTasteArg(ColorSpinorField &out, const ColorSpinorField &in) : + kernel_param(dim3(in.VolumeCB(), in.SiteSubset(), 1)), out(out), in(in) { - checkOrder(out_, in_); // check all orders match - checkPrecision(out_, in_); // check all precisions match - checkLocation(out_, in_); // check all locations match - if (!in_.isNative()) errorQuda("Unsupported field order colorspinor= %d \n", in_.FieldOrder()); - if (!out_.isNative()) errorQuda("Unsupported field order colorspinor= %d \n", out_.FieldOrder()); -#pragma unroll - for (int i = 0; i < 4; i++) { X[i] = in_.X()[i]; } + checkOrder(out, in); // check all orders match + checkPrecision(out, in); // check all precisions match + checkLocation(out, in); // check all locations match + checkNative(out, in); + for (int i = 0; i < 4; i++) { X[i] = in.X()[i]; } } }; diff --git a/lib/spin_taste.cu b/lib/spin_taste.cu index 824665f51e..0aefa2042e 100644 --- a/lib/spin_taste.cu +++ b/lib/spin_taste.cu @@ -67,22 +67,16 @@ namespace quda void preTune() { out.backup(); } void postTune() { out.restore(); } - long long flops() const { return 0; } long long bytes() const { return 2 * in.Bytes(); } }; -#ifdef GPU_STAGGERED_DIRAC void 
applySpinTaste(ColorSpinorField &out, const ColorSpinorField &in, QudaSpinTasteGamma gamma) { - instantiate(out, in, gamma); - //// ensure that ghosts are updated if needed - // if (u.GhostExchange() == QUDA_GHOST_EXCHANGE_PAD) u.exchangeGhost(); - } -#else - void applySpinTaste(ColorSpinorField &out, const ColorSpinorField &in, QudaSpinTasteGamma gamma) - { - errorQuda("Gauge tools are not build"); + if constexpr(is_enabled()) { + instantiate(out, in, gamma); + } else { + errorQuda("Staggered operator has not been built"); + } } -#endif } // namespace quda From 7c4793bf538c59aa933bb6b2bbfb75f937ea5108 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Fri, 21 Jun 2024 11:05:49 -0700 Subject: [PATCH 002/103] Some cleanup of CG interface --- include/invert_quda.h | 11 ++++++++--- lib/interface_quda.cpp | 2 +- lib/inv_cg_quda.cpp | 8 ++++---- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/include/invert_quda.h b/include/invert_quda.h index 816313fbd4..8cd4730d68 100644 --- a/include/invert_quda.h +++ b/include/invert_quda.h @@ -724,12 +724,16 @@ namespace quda { CG(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, const DiracMatrix &matEig, SolverParam ¶m); virtual ~CG(); + /** * @brief Run CG. * @param out Solution vector. * @param in Right-hand side. */ - void operator()(ColorSpinorField &out, ColorSpinorField &in) override { (*this)(out, in, nullptr, 0.0); }; + void operator()(ColorSpinorField &out, ColorSpinorField &in) override + { + (*this)(out, in, ColorSpinorField(), 0.0); + }; /** * @brief Solve re-using an initial Krylov space defined by an initial r2_old_init and search direction p_init. @@ -739,7 +743,8 @@ namespace quda { * @param p_init Initial-search direction. 
* @param r2_old_init [description] */ - void operator()(ColorSpinorField &out, ColorSpinorField &in, ColorSpinorField *p_init, double r2_old_init); + void operator()(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &p_init, + double r2_old_init); void blocksolve(ColorSpinorField &out, ColorSpinorField &in) override; @@ -758,7 +763,7 @@ namespace quda { * @param out Solution-vector. * @param in Right-hand side. */ - void hqsolve(ColorSpinorField &out, ColorSpinorField &in); + void hqsolve(ColorSpinorField &out, const ColorSpinorField &in); }; class CGNE : public CG diff --git a/lib/interface_quda.cpp b/lib/interface_quda.cpp index 0b1bfe5473..b7c046f955 100644 --- a/lib/interface_quda.cpp +++ b/lib/interface_quda.cpp @@ -4041,7 +4041,7 @@ void invertMultiShiftQuda(void **hp_x, void *hp_b, QudaInvertParam *param) { CG cg(*m, *mSloppy, *mSloppy, *mSloppy, solverParam); if (i==0) - cg(x[i], b, &p[i], r2_old[i]); + cg(x[i], b, p[i], r2_old[i]); else cg(x[i], b); } diff --git a/lib/inv_cg_quda.cpp b/lib/inv_cg_quda.cpp index 19b1d1225d..0951419774 100644 --- a/lib/inv_cg_quda.cpp +++ b/lib/inv_cg_quda.cpp @@ -206,7 +206,7 @@ namespace quda { } } - void CG::operator()(ColorSpinorField &x, ColorSpinorField &b, ColorSpinorField *p_init, double r2_old_init) + void CG::operator()(ColorSpinorField &x, const ColorSpinorField &b, const ColorSpinorField &p_init, double r2_old_init) { if (param.is_preconditioner) commGlobalReductionPush(param.global_reduction); @@ -315,10 +315,10 @@ namespace quda { ColorSpinorParam csParam(rSloppy); csParam.create = QUDA_NULL_FIELD_CREATE; - XUpdateBatch x_update_batch(Np, p_init ? *p_init : rSloppy, csParam); + XUpdateBatch x_update_batch(Np, !p_init.empty() ? 
p_init : rSloppy, csParam); double r2_old = 0.0; - if (r2_old_init != 0.0 and p_init) { + if (r2_old_init != 0.0 and !p_init.empty()) { r2_old = r2_old_init; Complex rp = blas::cDotProduct(rSloppy, x_update_batch.get_current_field()) / (r2); blas::caxpy(-rp, rSloppy, x_update_batch.get_current_field()); @@ -541,7 +541,7 @@ namespace quda { } // Separate HQ residual codepath - void CG::hqsolve(ColorSpinorField &x, ColorSpinorField &b) + void CG::hqsolve(ColorSpinorField &x, const ColorSpinorField &b) { logQuda(QUDA_VERBOSE, "Performing a HQ CG solve\n"); From 9a2190c4791f938da71b11396fa07b31604d0506 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Fri, 21 Jun 2024 14:23:58 -0700 Subject: [PATCH 003/103] Add MRHS interface for all solvers, and mandate source vector is const --- include/accelerator.h | 9 ++- include/invert_quda.h | 151 +++++++++++++++++++++++++++++-------- include/multigrid.h | 7 +- lib/inv_bicgstab_quda.cpp | 4 +- lib/inv_bicgstabl_quda.cpp | 2 +- lib/inv_ca_cg.cpp | 6 +- lib/inv_ca_gcr.cpp | 2 +- lib/inv_cg3_quda.cpp | 6 +- lib/inv_cg_quda.cpp | 4 +- lib/inv_eigcg_quda.cpp | 10 ++- lib/inv_gcr_quda.cpp | 17 +++-- lib/inv_gmresdr_quda.cpp | 2 +- lib/inv_mr_quda.cpp | 2 +- lib/inv_pcg_quda.cpp | 3 +- lib/inv_sd_quda.cpp | 2 +- lib/multigrid.cpp | 11 +-- 16 files changed, 170 insertions(+), 68 deletions(-) diff --git a/include/accelerator.h b/include/accelerator.h index 2711859663..53b5a811c7 100644 --- a/include/accelerator.h +++ b/include/accelerator.h @@ -46,7 +46,12 @@ namespace quda * @param out Solution vector. * @param in Right-hand side. 
*/ - virtual void operator()(ColorSpinorField &out, ColorSpinorField &in) + virtual void operator()(cvector_ref &out, cvector_ref &in) override + { + for (auto i = 0u; i < in.size(); i++) operator()(out[i], in[i]); + } + + void operator()(ColorSpinorField &out, const ColorSpinorField &in) { if (transformer.trained) { transformer.apply(*base_solver, out, in); @@ -64,7 +69,7 @@ namespace quda * @param null Solver to solve for null vectors. * @param in meta color spinor field. */ - virtual void train_param(Solver &null, ColorSpinorField &in) + virtual void train_param(Solver &null, const ColorSpinorField &in) { if (!active_training && !transformer.trained) { active_training = true; diff --git a/include/invert_quda.h b/include/invert_quda.h index 8cd4730d68..19d02a2635 100644 --- a/include/invert_quda.h +++ b/include/invert_quda.h @@ -18,8 +18,10 @@ namespace quda { #ifdef __CUDACC__ #ifdef __NVCC_DIAG_PRAGMA_SUPPORT__ #pragma nv_diag_suppress 611 +#pragma nv_diag_suppress 997 #else #pragma diag_suppress 611 +#pragma diag_suppress 997 #endif #endif @@ -427,15 +429,10 @@ namespace quda { const DiracMatrix &matEig, SolverParam ¶m); virtual ~Solver(); - virtual void operator()(ColorSpinorField &out, ColorSpinorField &in) = 0; - /** @brief Naive loop over RHS, for solvers that are not yet multi-RHS aware */ - void operator()(cvector_ref &out, cvector_ref &in) - { - for (auto i = 0u; i < in.size(); i++) { this->operator()(out[i], in[i]); } - } + virtual void operator()(cvector_ref &out, cvector_ref &in) = 0; virtual void blocksolve(ColorSpinorField &out, ColorSpinorField &in); @@ -455,7 +452,7 @@ namespace quda { @param Solver the solver to be used to collect the null space vectors. @param ColorSpinorField the vector used to perform the training. 
*/ - virtual void train_param(Solver &, ColorSpinorField &) + virtual void train_param(Solver &, const ColorSpinorField &) { // Do nothing } @@ -464,7 +461,8 @@ namespace quda { @brief a virtual method that performs the inversion and collect some vectors. The default here is a no-op and should not be called. */ - virtual void solve_and_collect(ColorSpinorField &, ColorSpinorField &, cvector_ref &, int, double) + virtual void solve_and_collect(ColorSpinorField &, const ColorSpinorField &, cvector_ref &, int, + double) { errorQuda("NOT implemented."); } @@ -730,9 +728,9 @@ namespace quda { * @param out Solution vector. * @param in Right-hand side. */ - void operator()(ColorSpinorField &out, ColorSpinorField &in) override + void operator()(cvector_ref &out, cvector_ref &in) override { - (*this)(out, in, ColorSpinorField(), 0.0); + for (auto i = 0u; i < in.size(); i++) (*this)(out[i], in[i], ColorSpinorField(), 0.0); }; /** @@ -789,7 +787,12 @@ namespace quda { CGNE(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, const DiracMatrix &matEig, SolverParam ¶m); - void operator()(ColorSpinorField &out, ColorSpinorField &in) override; + void operator()(cvector_ref &out, cvector_ref &in) override + { + for (auto i = 0u; i < in.size(); i++) operator()(out[i], in[i]); + } + + void operator()(ColorSpinorField &out, const ColorSpinorField &in); /** @return Return the residual vector from the prior solve @@ -823,7 +826,12 @@ namespace quda { CGNR(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, const DiracMatrix &matEig, SolverParam ¶m); - void operator()(ColorSpinorField &out, ColorSpinorField &in) override; + void operator()(cvector_ref &out, cvector_ref &in) override + { + for (auto i = 0u; i < in.size(); i++) operator()(out[i], in[i]); + } + + void operator()(ColorSpinorField &out, const ColorSpinorField &in); /** @return Return the residual vector from the prior solve @@ -859,7 +867,12 @@ namespace quda { 
public: CG3(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, SolverParam ¶m); - void operator()(ColorSpinorField &out, ColorSpinorField &in) override; + void operator()(cvector_ref &out, cvector_ref &in) override + { + for (auto i = 0u; i < in.size(); i++) operator()(out[i], in[i]); + } + + void operator()(ColorSpinorField &out, const ColorSpinorField &in); /** @return Return the residual vector from the prior solve @@ -892,7 +905,12 @@ namespace quda { public: CG3NE(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, SolverParam ¶m); - void operator()(ColorSpinorField &out, ColorSpinorField &in) override; + void operator()(cvector_ref &out, cvector_ref &in) override + { + for (auto i = 0u; i < in.size(); i++) operator()(out[i], in[i]); + } + + void operator()(ColorSpinorField &out, const ColorSpinorField &in); /** @return Return the residual vector from the prior solve @@ -924,7 +942,12 @@ namespace quda { public: CG3NR(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, SolverParam ¶m); - void operator()(ColorSpinorField &out, ColorSpinorField &in); + void operator()(cvector &out, cvector_ref &in) + { + for (auto i = 0u; i < in.size(); i++) operator()(out[i], in[i]); + } + + void operator()(ColorSpinorField &out, const ColorSpinorField &in); /** @return Return the residual vector from the prior solve @@ -980,9 +1003,10 @@ namespace quda { virtual ~PreconCG(); - void operator()(ColorSpinorField &out, ColorSpinorField &in) override + void operator()(cvector_ref &out, cvector_ref &in) override { - this->solve_and_collect(out, in, cvector_ref(), 0, 0); + for (auto i = 0u; i < in.size(); i++) + this->solve_and_collect(out[i], in[i], cvector_ref(), 0, 0); } /** @@ -993,8 +1017,8 @@ namespace quda { @param collect_miniter minimal iteration start from which the r vectors are to be collected @param collect_tol maxiter tolerance start from which the r vectors are to be 
collected */ - virtual void solve_and_collect(ColorSpinorField &out, ColorSpinorField &in, cvector_ref &v_r, - int collect_miniter, double collect_tol) override; + virtual void solve_and_collect(ColorSpinorField &out, const ColorSpinorField &in, + cvector_ref &v_r, int collect_miniter, double collect_tol) override; virtual bool hermitian() const override { return true; } /** PCG is only Hermitian system */ @@ -1029,7 +1053,12 @@ namespace quda { const DiracMatrix &matEig, SolverParam ¶m); virtual ~BiCGstab(); - void operator()(ColorSpinorField &out, ColorSpinorField &in) override; + void operator()(cvector_ref &out, cvector_ref &in) override + { + for (auto i = 0u; i < in.size(); i++) operator()(out[i], in[i]); + } + + void operator()(ColorSpinorField &out, const ColorSpinorField &in); /** @return Return the residual vector from the prior solve @@ -1140,7 +1169,12 @@ namespace quda { BiCGstabL(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matEig, SolverParam ¶m); virtual ~BiCGstabL(); - void operator()(ColorSpinorField &out, ColorSpinorField &in) override; + void operator()(cvector_ref &out, cvector_ref &in) override + { + for (auto i = 0u; i < in.size(); i++) operator()(out[i], in[i]); + } + + void operator()(ColorSpinorField &out, const ColorSpinorField &in); virtual bool hermitian() const override { return false; } /** BiCGStab is for any linear system */ @@ -1172,6 +1206,7 @@ namespace quda { ColorSpinorField r; //! residual vector ColorSpinorField r_sloppy; //! sloppy residual vector + int k_break = 0; //! 
track when the solver converged std::vector p; // GCR direction vectors std::vector Ap; // mat * direction vectors @@ -1201,7 +1236,17 @@ namespace quda { const DiracMatrix &matEig, SolverParam ¶m); virtual ~GCR(); - void operator()(ColorSpinorField &out, ColorSpinorField &in) override; + void operator()(cvector_ref &out, cvector_ref &in) override + { + for (auto i = 0u; i < in.size(); i++) operator()(out[i], in[i]); + } + + void operator()(ColorSpinorField &out, const ColorSpinorField &in); + + /** + @return Return the residual vector from the prior solve + */ + ColorSpinorField &get_residual() override; virtual bool hermitian() const override { return false; } /** GCR is for any linear system */ @@ -1227,7 +1272,12 @@ namespace quda { public: MR(const DiracMatrix &mat, const DiracMatrix &matSloppy, SolverParam ¶m); - void operator()(ColorSpinorField &out, ColorSpinorField &in) override; + void operator()(cvector_ref &out, cvector_ref &in) override + { + for (auto i = 0u; i < in.size(); i++) operator()(out[i], in[i]); + } + + void operator()(ColorSpinorField &out, const ColorSpinorField &in); /** @return Return the residual vector from the prior solve @@ -1299,7 +1349,12 @@ namespace quda { SolverParam ¶m); virtual ~CACG(); - void operator()(ColorSpinorField &out, ColorSpinorField &in) override; + void operator()(cvector_ref &out, cvector_ref &in) override + { + for (auto i = 0u; i < in.size(); i++) operator()(out[i], in[i]); + } + + void operator()(ColorSpinorField &out, const ColorSpinorField &in); /** @return Return the residual vector from the prior solve @@ -1333,7 +1388,12 @@ namespace quda { CACGNE(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, const DiracMatrix &matEig, SolverParam ¶m); - void operator()(ColorSpinorField &out, ColorSpinorField &in) override; + void operator()(cvector_ref &out, cvector_ref &in) override + { + for (auto i = 0u; i < in.size(); i++) operator()(out[i], in[i]); + } + + void 
operator()(ColorSpinorField &out, const ColorSpinorField &in); /** @return Return the residual vector from the prior solve @@ -1367,7 +1427,12 @@ namespace quda { CACGNR(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, const DiracMatrix &matEig, SolverParam ¶m); - void operator()(ColorSpinorField &out, ColorSpinorField &in) override; + void operator()(cvector_ref &out, cvector_ref &in) override + { + for (auto i = 0u; i < in.size(); i++) operator()(out[i], in[i]); + } + + void operator()(ColorSpinorField &out, const ColorSpinorField &in); /** @return Return the residual vector from the prior solve @@ -1423,7 +1488,12 @@ namespace quda { SolverParam ¶m); virtual ~CAGCR(); - void operator()(ColorSpinorField &out, ColorSpinorField &in) override; + void operator()(cvector_ref &out, cvector_ref &in) override + { + for (auto i = 0u; i < in.size(); i++) operator()(out[i], in[i]); + } + + void operator()(ColorSpinorField &out, const ColorSpinorField &in); /** @return Return the residual vector from the prior solve @@ -1452,7 +1522,12 @@ namespace quda { public: SD(const DiracMatrix &mat, SolverParam ¶m); - void operator()(ColorSpinorField &out, ColorSpinorField &in) override; + void operator()(cvector_ref &out, cvector_ref &in) override + { + for (auto i = 0u; i < in.size(); i++) operator()(out[i], in[i]); + } + + void operator()(ColorSpinorField &out, const ColorSpinorField &in); /** @return Return the residual from the prior solve @@ -1482,7 +1557,7 @@ namespace quda { virtual ~PreconditionedSolver() { delete solver; } - void operator()(ColorSpinorField &x, ColorSpinorField &b) override + void operator()(cvector_ref &x, cvector_ref &b) override { pushOutputPrefix(prefix); @@ -1685,11 +1760,16 @@ namespace quda { void RestartVT(const double beta, const double rho); void UpdateVm(ColorSpinorField &res, double beta, double sqrtr2); // EigCG solver: - int eigCGsolve(ColorSpinorField &out, ColorSpinorField &in); + int 
eigCGsolve(ColorSpinorField &out, const ColorSpinorField &in); // InitCG solver: - int initCGsolve(ColorSpinorField &out, ColorSpinorField &in); + int initCGsolve(ColorSpinorField &out, const ColorSpinorField &in); // Incremental eigCG solver (for eigcg and initcg calls) - void operator()(ColorSpinorField &out, ColorSpinorField &in); + void operator()(cvector_ref &out, cvector_ref &in) override + { + for (auto i = 0u; i < in.size(); i++) operator()(out[i], in[i]); + } + + void operator()(ColorSpinorField &out, const ColorSpinorField &in); virtual bool hermitian() const final { return true; } // EigCG is only for Hermitian systems @@ -1726,7 +1806,12 @@ namespace quda { virtual ~GMResDR(); //GMRES-DR solver - void operator()(ColorSpinorField &out, ColorSpinorField &in); + void operator()(cvector_ref &out, cvector_ref &in) override + { + for (auto i = 0u; i < in.size(); i++) operator()(out[i], in[i]); + } + + void operator()(ColorSpinorField &out, const ColorSpinorField &in); // //GMRESDR method void RunDeflatedCycles (ColorSpinorField *out, ColorSpinorField *in, const double tol_threshold); diff --git a/include/multigrid.h b/include/multigrid.h index 35253bd0a6..e29816f821 100644 --- a/include/multigrid.h +++ b/include/multigrid.h @@ -440,7 +440,12 @@ namespace quda { @param out The solution vector @param in The residual vector (or equivalently the right hand side vector) */ - void operator()(ColorSpinorField &out, ColorSpinorField &in); + void operator()(cvector_ref &out, cvector_ref &in) + { + for (auto i = 0u; i < in.size(); i++) operator()(out[i], in[i]); + } + + void operator()(ColorSpinorField &out, const ColorSpinorField &in); /** @brief Load the null space vectors in from file diff --git a/lib/inv_bicgstab_quda.cpp b/lib/inv_bicgstab_quda.cpp index 7762ad1c10..b1468c1b7d 100644 --- a/lib/inv_bicgstab_quda.cpp +++ b/lib/inv_bicgstab_quda.cpp @@ -61,7 +61,7 @@ namespace quda { return updateR; } - void BiCGstab::operator()(ColorSpinorField &x, ColorSpinorField 
&b) + void BiCGstab::operator()(ColorSpinorField &x, const ColorSpinorField &b) { create(x, b); @@ -138,7 +138,7 @@ namespace quda { r_sloppy = r.create_alias(); if (param.compute_null_vector == QUDA_COMPUTE_NULL_VECTOR_NO) { - r0 = b.create_alias(); + r0 = const_cast(b).create_alias(); } else { ColorSpinorParam csParam(r); csParam.create = QUDA_NULL_FIELD_CREATE; diff --git a/lib/inv_bicgstabl_quda.cpp b/lib/inv_bicgstabl_quda.cpp index cf5c248297..9d02e87ca6 100644 --- a/lib/inv_bicgstabl_quda.cpp +++ b/lib/inv_bicgstabl_quda.cpp @@ -434,7 +434,7 @@ namespace quda { } } - void BiCGstabL::operator()(ColorSpinorField &x, ColorSpinorField &b) + void BiCGstabL::operator()(ColorSpinorField &x, const ColorSpinorField &b) { // BiCGstab-l is based on the algorithm outlined in // BICGSTAB(L) FOR LINEAR EQUATIONS INVOLVING UNSYMMETRIC MATRICES WITH COMPLEX SPECTRUM diff --git a/lib/inv_ca_cg.cpp b/lib/inv_ca_cg.cpp index 02b05c91bf..b7fdd5eb9a 100644 --- a/lib/inv_ca_cg.cpp +++ b/lib/inv_ca_cg.cpp @@ -59,7 +59,7 @@ namespace quda } // CACGNE: M Mdag y = b is solved; x = Mdag y is returned as solution. - void CACGNE::operator()(ColorSpinorField &x, ColorSpinorField &b) + void CACGNE::operator()(ColorSpinorField &x, const ColorSpinorField &b) { if (param.maxiter == 0 || param.Nsteps == 0) { if (param.use_init_guess == QUDA_USE_INIT_GUESS_NO) blas::zero(x); @@ -138,7 +138,7 @@ namespace quda } // CACGNR: Mdag M x = Mdag b is solved. - void CACGNR::operator()(ColorSpinorField &x, ColorSpinorField &b) + void CACGNR::operator()(ColorSpinorField &x, const ColorSpinorField &b) { if (param.maxiter == 0 || param.Nsteps == 0) { if (param.use_init_guess == QUDA_USE_INIT_GUESS_NO) blas::zero(x); @@ -397,7 +397,7 @@ namespace quda 2. Steepest descent minmization of the residual in this basis 3. 
Update solution and residual vectors */ - void CACG::operator()(ColorSpinorField &x, ColorSpinorField &b) + void CACG::operator()(ColorSpinorField &x, const ColorSpinorField &b) { if (param.is_preconditioner) commGlobalReductionPush(param.global_reduction); diff --git a/lib/inv_ca_gcr.cpp b/lib/inv_ca_gcr.cpp index 4610df455e..7af84197c7 100644 --- a/lib/inv_ca_gcr.cpp +++ b/lib/inv_ca_gcr.cpp @@ -127,7 +127,7 @@ namespace quda 3. Update solution and residual vectors 4. (Optional) restart if convergence or maxiter not reached */ - void CAGCR::operator()(ColorSpinorField &x, ColorSpinorField &b) + void CAGCR::operator()(ColorSpinorField &x, const ColorSpinorField &b) { const int n_krylov = param.Nkrylov; diff --git a/lib/inv_cg3_quda.cpp b/lib/inv_cg3_quda.cpp index 30b9fff86c..9a33040af8 100644 --- a/lib/inv_cg3_quda.cpp +++ b/lib/inv_cg3_quda.cpp @@ -45,7 +45,7 @@ namespace quda { } // CG3NE: M Mdag y = b is solved; x = Mdag y is returned as solution. - void CG3NE::operator()(ColorSpinorField &x, ColorSpinorField &b) + void CG3NE::operator()(ColorSpinorField &x, const ColorSpinorField &b) { if (param.maxiter == 0 || param.Nsteps == 0) { if (param.use_init_guess == QUDA_USE_INIT_GUESS_NO) blas::zero(x); @@ -123,7 +123,7 @@ namespace quda { } // CG3NR: Mdag M x = Mdag b is solved. 
- void CG3NR::operator()(ColorSpinorField &x, ColorSpinorField &b) + void CG3NR::operator()(ColorSpinorField &x, const ColorSpinorField &b) { if (param.maxiter == 0 || param.Nsteps == 0) { if (param.use_init_guess == QUDA_USE_INIT_GUESS_NO) blas::zero(x); @@ -198,7 +198,7 @@ namespace quda { return r; } - void CG3::operator()(ColorSpinorField &x, ColorSpinorField &b) + void CG3::operator()(ColorSpinorField &x, const ColorSpinorField &b) { getProfile().TPSTART(QUDA_PROFILE_PREAMBLE); diff --git a/lib/inv_cg_quda.cpp b/lib/inv_cg_quda.cpp index 0951419774..dcb1106259 100644 --- a/lib/inv_cg_quda.cpp +++ b/lib/inv_cg_quda.cpp @@ -85,7 +85,7 @@ namespace quda { } // CGNE: M Mdag y = b is solved; x = Mdag y is returned as solution. - void CGNE::operator()(ColorSpinorField &x, ColorSpinorField &b) + void CGNE::operator()(ColorSpinorField &x, const ColorSpinorField &b) { if (param.maxiter == 0 || param.Nsteps == 0) { if (param.use_init_guess == QUDA_USE_INIT_GUESS_NO) blas::zero(x); @@ -164,7 +164,7 @@ namespace quda { } // CGNR: Mdag M x = Mdag b is solved. - void CGNR::operator()(ColorSpinorField &x, ColorSpinorField &b) + void CGNR::operator()(ColorSpinorField &x, const ColorSpinorField &b) { if (param.maxiter == 0 || param.Nsteps == 0) { if (param.use_init_guess == QUDA_USE_INIT_GUESS_NO) blas::zero(x); diff --git a/lib/inv_eigcg_quda.cpp b/lib/inv_eigcg_quda.cpp index ad0747ace4..9961f8ef8e 100644 --- a/lib/inv_eigcg_quda.cpp +++ b/lib/inv_eigcg_quda.cpp @@ -335,7 +335,8 @@ namespace quda { /* * This is a solo precision solver. 
*/ - int IncEigCG::eigCGsolve(ColorSpinorField &x, ColorSpinorField &b) { + int IncEigCG::eigCGsolve(ColorSpinorField &x, const ColorSpinorField &b) + { int k=0; @@ -358,7 +359,7 @@ namespace quda { eigcg_args = new EigCGArgs(param.m, param.n_ev); // need only deflation meta structure csParam.create = QUDA_COPY_FIELD_CREATE; - csParam.field = &b; + csParam.field = &const_cast(b); rp = ColorSpinorField::Create(csParam); csParam.create = QUDA_ZERO_FIELD_CREATE; yp = ColorSpinorField::Create(csParam); @@ -525,7 +526,8 @@ namespace quda { return k; } - int IncEigCG::initCGsolve(ColorSpinorField &x, ColorSpinorField &b) { + int IncEigCG::initCGsolve(ColorSpinorField &x, const ColorSpinorField &b) + { int k = 0; //Start init CG iterations: deflated_solver *defl_p = static_cast(param.deflation_op); @@ -586,7 +588,7 @@ namespace quda { return k; } - void IncEigCG::operator()(ColorSpinorField &out, ColorSpinorField &in) + void IncEigCG::operator()(ColorSpinorField &out, const ColorSpinorField &in) { if(param.rhs_idx == 0) max_eigcg_cycles = param.eigcg_max_restarts; diff --git a/lib/inv_gcr_quda.cpp b/lib/inv_gcr_quda.cpp index 45db2e5536..6066f1d14c 100644 --- a/lib/inv_gcr_quda.cpp +++ b/lib/inv_gcr_quda.cpp @@ -175,7 +175,16 @@ namespace quda { } } - void GCR::operator()(ColorSpinorField &x, ColorSpinorField &b) + ColorSpinorField &GCR::get_residual() + { + if (!init) errorQuda("No residual vector present"); + if (param.compute_true_res) + return r; + else + return K ? 
r_sloppy : p[k_break]; + } + + void GCR::operator()(ColorSpinorField &x, const ColorSpinorField &b) { if (n_krylov == 0) { // Krylov space is zero-dimensional so return doing no work @@ -290,7 +299,7 @@ namespace quda { getProfile().TPSTART(QUDA_PROFILE_COMPUTE); int k = 0; - int k_break = 0; + k_break = 0; PrintStats("GCR", total_iter+k, r2, b2, heavy_quark_res); while ( !convergence(r2, heavy_quark_res, stop, param.tol_hq) && total_iter < param.maxiter) { @@ -394,10 +403,6 @@ namespace quda { param.true_res_hq = sqrt(blas::HeavyQuarkResidualNorm(x,r).z); else param.true_res_hq = 0.0; - //if (param.preserve_source == QUDA_PRESERVE_SOURCE_NO) blas::copy(b, r); - } else { - // reuse this when we add the get_residual method to GCR - if (0) blas::copy(b, K ? r_sloppy : p[k_break]); } param.iter += total_iter; diff --git a/lib/inv_gmresdr_quda.cpp b/lib/inv_gmresdr_quda.cpp index 8d2b323146..6853203dab 100644 --- a/lib/inv_gmresdr_quda.cpp +++ b/lib/inv_gmresdr_quda.cpp @@ -379,7 +379,7 @@ namespace quda { return (j - start_idx); } - void GMResDR::operator()(ColorSpinorField &x, ColorSpinorField &b) + void GMResDR::operator()(ColorSpinorField &x, const ColorSpinorField &b) { getProfile().TPSTART(QUDA_PROFILE_INIT); diff --git a/lib/inv_mr_quda.cpp b/lib/inv_mr_quda.cpp index df8f0f2dbc..b50c349c8f 100644 --- a/lib/inv_mr_quda.cpp +++ b/lib/inv_mr_quda.cpp @@ -53,7 +53,7 @@ namespace quda return r; } - void MR::operator()(ColorSpinorField &x, ColorSpinorField &b) + void MR::operator()(ColorSpinorField &x, const ColorSpinorField &b) { if (param.maxiter == 0 || param.Nsteps == 0) { if (param.use_init_guess == QUDA_USE_INIT_GUESS_NO) blas::zero(x); diff --git a/lib/inv_pcg_quda.cpp b/lib/inv_pcg_quda.cpp index 4fb03527ba..4b15b58ac2 100644 --- a/lib/inv_pcg_quda.cpp +++ b/lib/inv_pcg_quda.cpp @@ -100,8 +100,7 @@ namespace quda } } - void PreconCG::solve_and_collect(ColorSpinorField &x, ColorSpinorField &b, - cvector_ref &v_r, + void 
PreconCG::solve_and_collect(ColorSpinorField &x, const ColorSpinorField &b, cvector_ref &v_r, int collect_miniter, double collect_tol) { if (K) K->train_param(*this, b); diff --git a/lib/inv_sd_quda.cpp b/lib/inv_sd_quda.cpp index a715f89fc0..484b06c7f2 100644 --- a/lib/inv_sd_quda.cpp +++ b/lib/inv_sd_quda.cpp @@ -30,7 +30,7 @@ namespace quda { return r; } - void SD::operator()(ColorSpinorField &x, ColorSpinorField &b) + void SD::operator()(ColorSpinorField &x, const ColorSpinorField &b) { commGlobalReductionPush(param.global_reduction); diff --git a/lib/multigrid.cpp b/lib/multigrid.cpp index a327717e84..e612127dfe 100644 --- a/lib/multigrid.cpp +++ b/lib/multigrid.cpp @@ -1130,7 +1130,7 @@ namespace quda popLevel(); } - void MG::operator()(ColorSpinorField &x, ColorSpinorField &b) + void MG::operator()(ColorSpinorField &x, const ColorSpinorField &b) { pushOutputPrefix(prefix); @@ -1194,12 +1194,13 @@ namespace quda false; // FIXME this is currently borked if inner solver is preconditioned - ColorSpinorField &residual = !presmoother ? b : - use_solver_residual ? presmoother->get_residual() : - b.SiteSubset() == QUDA_FULL_SITE_SUBSET ? r : - r.Even(); + const ColorSpinorField &residual = !presmoother ? b : + use_solver_residual ? presmoother->get_residual() : + b.SiteSubset() == QUDA_FULL_SITE_SUBSET ? r : + r.Even(); if (!use_solver_residual && presmoother) { + auto &residual = b.SiteSubset() == QUDA_FULL_SITE_SUBSET ? 
r : r.Even(); (*param.matResidual)(residual, x); axpby(1.0, b, -1.0, residual); } From fda46692da2139b9f462f529a38c0435b4133fe4 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Wed, 26 Jun 2024 06:36:14 -0700 Subject: [PATCH 004/103] Optimize DiracWilson: vectorize the prepare/reconstruct functions --- include/dirac_quda.h | 2 -- lib/dirac_wilson.cpp | 19 +++++-------------- 2 files changed, 5 insertions(+), 16 deletions(-) diff --git a/include/dirac_quda.h b/include/dirac_quda.h index f68b235fde..de9c9e4dd8 100644 --- a/include/dirac_quda.h +++ b/include/dirac_quda.h @@ -469,7 +469,6 @@ namespace quda { DiracWilson(const DiracWilson &dirac); DiracWilson(const DiracParam ¶m, const int nDims); // to correctly adjust face for DW and non-deg twisted mass - virtual ~DiracWilson(); DiracWilson& operator=(const DiracWilson &dirac); virtual void Dslash(cvector_ref &out, cvector_ref &in, @@ -518,7 +517,6 @@ namespace quda { public: DiracWilsonPC(const DiracParam ¶m); DiracWilsonPC(const DiracWilsonPC &dirac); - virtual ~DiracWilsonPC(); DiracWilsonPC& operator=(const DiracWilsonPC &dirac); void M(cvector_ref &out, cvector_ref &in) const; diff --git a/lib/dirac_wilson.cpp b/lib/dirac_wilson.cpp index 30565cfba3..3c64ff2222 100644 --- a/lib/dirac_wilson.cpp +++ b/lib/dirac_wilson.cpp @@ -12,8 +12,6 @@ namespace quda { // hack (for DW and TM operators) DiracWilson::DiracWilson(const DiracParam ¶m, const int) : Dirac(param) { } - DiracWilson::~DiracWilson() { } - DiracWilson& DiracWilson::operator=(const DiracWilson &dirac) { if (&dirac != this) { Dirac::operator=(dirac); } @@ -105,11 +103,6 @@ namespace quda { DiracWilsonPC::DiracWilsonPC(const DiracWilsonPC &dirac) : DiracWilson(dirac) { } - DiracWilsonPC::~DiracWilsonPC() - { - - } - DiracWilsonPC& DiracWilsonPC::operator=(const DiracWilsonPC &dirac) { if (&dirac != this) { @@ -155,9 +148,9 @@ namespace quda { } // we desire solution to full system + // src = b_e + k D_eo b_o + DslashXpay(x(other_parity), 
b(other_parity), this_parity, b(this_parity), kappa); for (auto i = 0u; i < b.size(); i++) { - // src = b_e + k D_eo b_o - DslashXpay(x[i][other_parity], b[i][other_parity], this_parity, b[i][this_parity], kappa); src[i] = x[i][other_parity].create_alias(); sol[i] = x[i][this_parity].create_alias(); } @@ -169,11 +162,9 @@ namespace quda { if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) return; // create full solution - for (auto i = 0u; i < b.size(); i++) { - checkFullSpinor(x[i], b[i]); - // x_o = b_o + k D_oe x_e - DslashXpay(x[i][other_parity], x[i][this_parity], other_parity, b[i][other_parity], kappa); - } + checkFullSpinor(x, b); + // x_o = b_o + k D_oe x_e + DslashXpay(x(other_parity), x(this_parity), other_parity, b(other_parity), kappa); } } // namespace quda From 44fb98a1759b1f7f9291acdb97d98cc70ed9548e Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Wed, 26 Jun 2024 06:36:35 -0700 Subject: [PATCH 005/103] Small cleanup to block_transpose.in.cu --- lib/block_transpose.in.cu | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/lib/block_transpose.in.cu b/lib/block_transpose.in.cu index 952c6ad084..505613fdb7 100644 --- a/lib/block_transpose.in.cu +++ b/lib/block_transpose.in.cu @@ -37,10 +37,7 @@ namespace quda } else { strcat(aux, ",b2v"); } - strcat(aux, ",n_rhs="); - char rhs_str[8]; - i32toa(rhs_str, B.size()); - strcat(aux, rhs_str); + setRHSstring(aux, B.size()); resizeStep(1); apply(device::get_default_stream()); } @@ -138,7 +135,7 @@ namespace quda if constexpr (sizeof...(N) > 0) { launch_span_nColor(V, B, nVecs); } else { - errorQuda("nColor = %d not instantiated\n", V.Ncolor()); + errorQuda("nColor = %d not instantiated", V.Ncolor()); } } } @@ -184,7 +181,7 @@ namespace quda } else if (V.Precision() == QUDA_SINGLE_PRECISION && B[0].Precision() == QUDA_SINGLE_PRECISION) { if constexpr (is_enabled(QUDA_SINGLE_PRECISION)) block_transpose(V, B); } else { - errorQuda("Unsupported precision 
combination V=%d B=%d\n", V.Precision(), B[0].Precision()); + errorQuda("Unsupported precision combination V=%d B=%d", V.Precision(), B[0].Precision()); } } else { errorQuda("Multigrid has not been built"); From fdd40fb09336b6990eecf55e7b1bd4ba4cd51df6 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Wed, 26 Jun 2024 07:06:19 -0700 Subject: [PATCH 006/103] Add new parameter: QudaMultigridParam::n_vec_batch which is the batch size to use when generating null-space vectors. This is the parameter used to enable MRHS null space generation. Updated the null-space generation to work on vectors of this width --- include/multigrid.h | 5 ++++ include/quda.h | 3 +++ lib/multigrid.cpp | 37 +++++++++++++++++++---------- tests/utils/command_line_params.cpp | 3 +++ tests/utils/command_line_params.h | 1 + tests/utils/set_params.cpp | 2 ++ 6 files changed, 38 insertions(+), 13 deletions(-) diff --git a/include/multigrid.h b/include/multigrid.h index e29816f821..aa500f6d56 100644 --- a/include/multigrid.h +++ b/include/multigrid.h @@ -99,6 +99,9 @@ namespace quda { /** Number of vectors used to define coarse space */ int Nvec; + /** Batch size when computing null space vectors */ + int n_vec_batch; + /** Number of times to apply Gram-Schmidt within a block */ int NblockOrtho; @@ -180,6 +183,7 @@ namespace quda { Nlevel(param.n_level), spinBlockSize(param.spin_block_size[level]), Nvec(param.n_vec[level]), + n_vec_batch(param.n_vec_batch[level]), NblockOrtho(param.n_block_ortho[level]), blockOrthoTwoPass(param.block_ortho_two_pass[level]), B(B), @@ -216,6 +220,7 @@ namespace quda { Nlevel(param.Nlevel), spinBlockSize(param.mg_global.spin_block_size[level]), Nvec(param.mg_global.n_vec[level]), + n_vec_batch(param.mg_global.n_vec_batch[level]), NblockOrtho(param.mg_global.n_block_ortho[level]), blockOrthoTwoPass(param.mg_global.block_ortho_two_pass[level]), coarse(param.coarse), diff --git a/include/quda.h b/include/quda.h index ed974dbe9f..6f720d19b6 100644 --- a/include/quda.h +++ 
b/include/quda.h @@ -655,6 +655,9 @@ extern "C" { /** Inverter to use in the setup phase */ QudaInverterType setup_inv_type[QUDA_MAX_MG_LEVEL]; + /** Solver batch size to use in the setup phase */ + int n_vec_batch[QUDA_MAX_MG_LEVEL]; + /** Number of setup iterations */ int num_setup_iter[QUDA_MAX_MG_LEVEL]; diff --git a/lib/multigrid.cpp b/lib/multigrid.cpp index e612127dfe..5d7b9db979 100644 --- a/lib/multigrid.cpp +++ b/lib/multigrid.cpp @@ -1195,7 +1195,7 @@ namespace quda // FIXME this is currently borked if inner solver is preconditioned const ColorSpinorField &residual = !presmoother ? b : - use_solver_residual ? presmoother->get_residual() : + use_solver_residual ? presmoother->get_residual()[0] : b.SiteSubset() == QUDA_FULL_SITE_SUBSET ? r : r.Even(); @@ -1361,8 +1361,9 @@ namespace quda csParam.gammaBasis = B[0].Nspin() == 1 ? QUDA_DEGRAND_ROSSI_GAMMA_BASIS : QUDA_UKQCD_GAMMA_BASIS; // degrand-rossi required for staggered csParam.create = QUDA_ZERO_FIELD_CREATE; - ColorSpinorField b(csParam); - ColorSpinorField x(csParam); + std::vector b, x; + resize(b, param.n_vec_batch, csParam); + resize(x, param.n_vec_batch, csParam); csParam.create = QUDA_NULL_FIELD_CREATE; @@ -1429,25 +1430,35 @@ namespace quda } // launch solver for each source - for (auto i = 0u; i < B.size(); i++) { - if (param.mg_global.setup_type == QUDA_TEST_VECTOR_SETUP) { // DDalphaAMG test vector idea - b = B[i]; // inverting against the vector - zero(x); // with zero initial guess + if (B.size() % param.n_vec_batch != 0) errorQuda("Bad batch size %d", param.n_vec_batch); + for (auto i = 0u; i < B.size(); i += param.n_vec_batch) { + if (param.mg_global.setup_type + == QUDA_TEST_VECTOR_SETUP) { // DDalphaAMG test vector idea solving against the vector + copy({b.begin(), b.begin() + param.n_vec_batch}, {B.begin() + i, B.begin() + i + param.n_vec_batch}); + zero(x); // with zero initial guess } else { - x = B[i]; + copy({x.begin(), x.begin() + param.n_vec_batch}, {B.begin() + i, B.begin() + i 
+ param.n_vec_batch}); zero(b); } - logQuda(QUDA_VERBOSE, "Initial guess = %g\n", norm2(x)); - logQuda(QUDA_VERBOSE, "Initial rhs = %g\n", norm2(b)); + if (getVerbosity() >= QUDA_VERBOSE) { + auto nrm2 = norm2(x); + auto b2 = norm2(b); + for (auto j = 0; j < param.n_vec_batch; j++) + printfQuda("%d Initial guess = %g, Initial rhs = %g\n", i + j, nrm2[j], b2[j]); + } - ColorSpinorField out, in; + std::vector out(param.n_vec_batch), in(param.n_vec_batch); diracSmoother->prepare(out, in, x, b, QUDA_MAT_SOLUTION); (*solve)(out, in); diracSmoother->reconstruct(x, b, QUDA_MAT_SOLUTION); - logQuda(QUDA_VERBOSE, "Solution = %g\n", norm2(x)); - B[i] = x; + if (getVerbosity() >= QUDA_VERBOSE) { + auto nrm2 = norm2(x); + for (auto j = 0; j < param.n_vec_batch; j++) printfQuda("%d Solution = %g\n", i + j, nrm2[j]); + } + + copy({B.begin() + i, B.begin() + i + param.n_vec_batch}, {x.begin(), x.begin() + param.n_vec_batch}); } // global orthonormalization of the generated null-space vectors diff --git a/tests/utils/command_line_params.cpp b/tests/utils/command_line_params.cpp index 204c336ac2..fe937da3e5 100644 --- a/tests/utils/command_line_params.cpp +++ b/tests/utils/command_line_params.cpp @@ -66,6 +66,7 @@ int pipeline = 0; int solution_accumulator_pipeline = 0; int test_type = 0; quda::mgarray nvec = {}; +quda::mgarray nvec_batch = {}; quda::mgarray mg_vec_infile; quda::mgarray mg_vec_outfile; quda::mgarray mg_vec_partfile = {}; @@ -994,6 +995,8 @@ void add_multigrid_option_group(std::shared_ptr quda_app) "The number of pre-smoother applications to do at a given multigrid level (default 2)"); quda_app->add_mgoption(opgroup, "--mg-nvec", nvec, CLI::PositiveNumber, "Number of null-space vectors to define the multigrid transfer operator on a given level"); + quda_app->add_mgoption(opgroup, "--mg-nvec-batch", nvec_batch, CLI::PositiveNumber, + "Batch size to use when computing the null-space vectors to define the multigrid transfer operator on a given level"); 
opgroup->add_option("--mg-oblique-proj-check", oblique_proj_check, "Measure how well the null vector subspace adjusts the low eigenmode subspace (default false)"); opgroup->add_option("--mg-omega", omega, diff --git a/tests/utils/command_line_params.h b/tests/utils/command_line_params.h index 5b038e4cb0..ecd20ae065 100644 --- a/tests/utils/command_line_params.h +++ b/tests/utils/command_line_params.h @@ -324,6 +324,7 @@ extern int pipeline; extern int solution_accumulator_pipeline; extern int test_type; extern quda::mgarray nvec; +extern quda::mgarray nvec_batch; extern quda::mgarray mg_vec_infile; extern quda::mgarray mg_vec_outfile; extern quda::mgarray mg_vec_partfile; diff --git a/tests/utils/set_params.cpp b/tests/utils/set_params.cpp index 3ba1900cf0..a9e6fe2dc0 100644 --- a/tests/utils/set_params.cpp +++ b/tests/utils/set_params.cpp @@ -471,6 +471,7 @@ void setMultigridParam(QudaMultigridParam &mg_param) mg_param.spin_block_size[i] = 1; mg_param.n_vec[i] = nvec[i] == 0 ? 24 : nvec[i]; // default to 24 vectors if not set + mg_param.n_vec_batch[i] = nvec_batch[i] == 0 ? 1 : nvec_batch[i]; // default to batch size 1 if not set mg_param.n_block_ortho[i] = n_block_ortho[i]; // number of times to Gram-Schmidt mg_param.block_ortho_two_pass[i] = block_ortho_two_pass[i] ? QUDA_BOOLEAN_TRUE : QUDA_BOOLEAN_FALSE; // whether to use a two-pass block ortho @@ -1078,6 +1079,7 @@ void setStaggeredMultigridParam(QudaMultigridParam &mg_param) mg_param.spin_block_size[i] = 1; mg_param.n_vec[i] = nvec[i] == 0 ? 64 : nvec[i]; // default to 64 vectors if not set + mg_param.n_vec_batch[i] = nvec_batch[i] == 0 ? 1 : nvec_batch[i]; // default to batch size 1 if not set mg_param.n_block_ortho[i] = n_block_ortho[i]; // number of times to Gram-Schmidt mg_param.block_ortho_two_pass[i] = block_ortho_two_pass[i] ? 
QUDA_BOOLEAN_TRUE : QUDA_BOOLEAN_FALSE; // whether to use a two-pass block ortho From fa64adf4483185ffff27cbbf6690bf20525d7150 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Wed, 26 Jun 2024 14:39:46 -0700 Subject: [PATCH 007/103] Vectorize DiracCoarsePC prepare/reconstruct --- include/dirac_quda.h | 3 -- lib/dirac_coarse.cpp | 86 +++++++++++++++++++------------------------- 2 files changed, 36 insertions(+), 53 deletions(-) diff --git a/include/dirac_quda.h b/include/dirac_quda.h index de9c9e4dd8..55da680c99 100644 --- a/include/dirac_quda.h +++ b/include/dirac_quda.h @@ -1855,7 +1855,6 @@ namespace quda { @param[in] param Parameters defining this operator */ DiracCoarse(const DiracCoarse &dirac, const DiracParam &param); - virtual ~DiracCoarse(); virtual bool isCoarse() const { return true; } @@ -2000,8 +1999,6 @@ namespace quda { */ DiracCoarsePC(const DiracCoarse &dirac, const DiracParam &param); - virtual ~DiracCoarsePC(); - /** @brief Apply preconditioned Dslash out = (D * in) @param[out] out Output field diff --git a/lib/dirac_coarse.cpp b/lib/dirac_coarse.cpp index c138b95158..b8c646e84b 100644 --- a/lib/dirac_coarse.cpp +++ b/lib/dirac_coarse.cpp @@ -121,10 +121,6 @@ namespace quda { { } - DiracCoarse::~DiracCoarse() - { - } - void DiracCoarse::createY(bool gpu, bool mapped) const { int ndim = transfer->Vectors().Ndim(); @@ -513,8 +509,6 @@ namespace quda { /* do nothing */ } - DiracCoarsePC::~DiracCoarsePC() { } - void DiracCoarsePC::Dslash(cvector_ref &out, cvector_ref &in, QudaParity parity) const { @@ -592,61 +586,53 @@ namespace quda { return; } - auto tmp = getFieldTmp(b[0].Even()); - - // we desire solution to full system + auto tmp = getFieldTmp(x.Even()); for (auto i = 0u; i < b.size(); i++) { + src[i] = x[i][other_parity].create_alias(); + sol[i] = x[i][this_parity].create_alias(); + } - if (symmetric) { - // src = A_ee^-1 (b_e - D_eo A_oo^-1 b_o) - src[i] = x[i][other_parity].create_alias(); + // we desire solution to full system + if (symmetric) {
+ // src = A_ee^-1 (b_e - D_eo A_oo^-1 b_o) #if 0 - CloverInv(src[i], b[other_parity], other_parity); - DiracCoarse::Dslash(tmp, src[i], this_parity); - blas::xpay(b[i][this_parity], -1.0, tmp); - CloverInv(src[i], tmp, this_parity); + CloverInv(src, b(other_parity), other_parity); + DiracCoarse::Dslash(tmp, src, this_parity); + blas::xpay(b(this_parity), -1.0, tmp); + CloverInv(src, tmp, this_parity); +#else + // src = A_ee^{-1} b_e - (A_ee^{-1} D_eo) A_oo^{-1} b_o + CloverInv(src, b(other_parity), other_parity); + Dslash(tmp, src, this_parity); + CloverInv(src, b(this_parity), this_parity); + blas::axpy(-1.0, tmp, src); #endif - // src = A_ee^{-1} b_e - (A_ee^{-1} D_eo) A_oo^{-1} b_o - CloverInv(src[i], b[i][other_parity], other_parity); - Dslash(tmp, src[i], this_parity); - CloverInv(src[i], b[i][this_parity], this_parity); - blas::axpy(-1.0, tmp, src[i]); - - sol[i] = x[i][this_parity].create_alias(); - } else { - // src = b_e - D_eo A_oo^-1 b_o - src[i] = x[i][other_parity].create_alias(); - CloverInv(tmp, b[i][other_parity], other_parity); - DiracCoarse::Dslash(src[i], tmp, this_parity); - blas::xpay(b[i][this_parity], -1.0, src[i]); - sol[i] = x[i][this_parity].create_alias(); - } + } else { + // src = b_e - D_eo A_oo^-1 b_o + CloverInv(tmp, b(other_parity), other_parity); + DiracCoarse::Dslash(src, tmp, this_parity); + blas::xpay(b(this_parity), -1.0, src); } } void DiracCoarsePC::reconstruct(cvector_ref &x, cvector_ref &b, const QudaSolutionType solType) const { - if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) { - return; - } - - auto tmp = getFieldTmp(b[0].Even()); - - for (auto i = 0u; i < b.size(); i++) { - checkFullSpinor(x[i], b[i]); - -#if 0 - // x_o = A_oo^-1 (b_o - D_oe x_e) - DiracCoarse::Dslash(tmp, x.Even(), QUDA_ODD_PARITY); - blas::xpay(b.Odd(), -1.0, tmp); - CloverInv(x.Odd(), tmp, QUDA_ODD_PARITY); + if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) return; + + checkFullSpinor(x, 
b); + auto tmp = getFieldTmp(x.Even()); +#if 1 + // x_o = A_oo^-1 (b_o - D_oe x_e) + DiracCoarse::Dslash(tmp, x.Even(), QUDA_ODD_PARITY); + blas::xpay(b.Odd(), -1.0, tmp); + CloverInv(x.Odd(), tmp, QUDA_ODD_PARITY); +#else + // x_o = A_oo^{-1} b_o - (A_oo^{-1} D_oe) x_e + Dslash(tmp, x(this_parity), other_parity); + CloverInv(x(other_parity), b(other_parity), other_parity); + blas::axpy(-1.0, tmp, x(other_parity)); #endif - // x_o = A_oo^{-1} b_o - (A_oo^{-1} D_oe) x_e - Dslash(tmp, x[i][this_parity], other_parity); - CloverInv(x[i][other_parity], b[i][other_parity], other_parity); - blas::axpy(-1.0, tmp, x[i][other_parity]); - } } //Make the coarse operator one level down. For the preconditioned From fef58e8277a9ae209e6da1ac9216b56b3f7b33c0 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Wed, 26 Jun 2024 14:41:52 -0700 Subject: [PATCH 008/103] Ensure we don't enable large arg support for pre Volta architecture --- lib/targets/cuda/target_cuda.cmake | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/lib/targets/cuda/target_cuda.cmake b/lib/targets/cuda/target_cuda.cmake index 4b2ce8e89d..306e666195 100644 --- a/lib/targets/cuda/target_cuda.cmake +++ b/lib/targets/cuda/target_cuda.cmake @@ -82,19 +82,6 @@ endif() # CUDA specific QUDA options include(CMakeDependentOption) -# large arg support requires CUDA 12.1 -cmake_dependent_option(QUDA_LARGE_KERNEL_ARG "enable large kernel arg support" ON "${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.1" OFF ) -message(STATUS "Large kernel arguments supported: ${QUDA_LARGE_KERNEL_ARG}") -mark_as_advanced(QUDA_LARGE_KERNEL_ARG) - -# Set the maximum multi-RHS per kernel -if(QUDA_LARGE_KERNEL_ARG) - set(QUDA_MAX_MULTI_RHS "64" CACHE STRING "maximum number of simultaneous RHS in a kernel") -else() - set(QUDA_MAX_MULTI_RHS "16" CACHE STRING "maximum number of simultaneous RHS in a kernel") -endif() -message(STATUS "Max number of rhs per kernel: ${QUDA_MAX_MULTI_RHS}") - 
option(QUDA_VERBOSE_BUILD "display kernel register usage" OFF) option(QUDA_JITIFY "build QUDA using Jitify" OFF) option(QUDA_DOWNLOAD_NVSHMEM "Download NVSHMEM" OFF) @@ -139,6 +126,21 @@ endif() set_target_properties(quda PROPERTIES CUDA_ARCHITECTURES ${CMAKE_CUDA_ARCHITECTURES}) +# large arg support requires CUDA 12.1 and Volta+ +cmake_dependent_option(QUDA_LARGE_KERNEL_ARG "enable large kernel arg support" ON + "${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.1 AND ${QUDA_COMPUTE_CAPABILITY} GREATER_EQUAL 70" + OFF) +message(STATUS "Large kernel arguments supported: ${QUDA_LARGE_KERNEL_ARG}") +mark_as_advanced(QUDA_LARGE_KERNEL_ARG) + +# Set the maximum multi-RHS per kernel +if(QUDA_LARGE_KERNEL_ARG) + set(QUDA_MAX_MULTI_RHS "64" CACHE STRING "maximum number of simultaneous RHS in a kernel") +else() + set(QUDA_MAX_MULTI_RHS "16" CACHE STRING "maximum number of simultaneous RHS in a kernel") +endif() +message(STATUS "Max number of rhs per kernel: ${QUDA_MAX_MULTI_RHS}") + # QUDA_HASH for tunecache set(HASH cpu_arch=${CPU_ARCH},gpu_arch=${QUDA_GPU_ARCH},cuda_version=${CMAKE_CUDA_COMPILER_VERSION}) set(GITVERSION "${PROJECT_VERSION}-${GITVERSION}-${QUDA_GPU_ARCH}") From 8b8cd99a897e31cef13c80519f68a6050d249c7c Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Thu, 27 Jun 2024 12:32:24 -0700 Subject: [PATCH 009/103] Create vector variants of create_alias --- include/color_spinor_field.h | 30 ++++++++++++++++++++++++++++++ lib/color_spinor_util.in.cu | 14 ++++++++++++++ lib/dirac_coarse.cpp | 20 +++++++------------- lib/dirac_wilson.cpp | 20 +++++++------------- 4 files changed, 58 insertions(+), 26 deletions(-) diff --git a/include/color_spinor_field.h b/include/color_spinor_field.h index bbea1b7842..3023a4e248 100644 --- a/include/color_spinor_field.h +++ b/include/color_spinor_field.h @@ -943,6 +943,36 @@ namespace quda void resize(std::vector &v, size_t new_size, QudaFieldCreate create, const ColorSpinorField &src = ColorSpinorField()); + /** + @brief 
Create a vector of fields that aliases another vector of + fields' storage. The alias field can use a different precision + than this field, though it cannot be greater. This + functionality is useful for the case where we have multiple + temporaries in different precisions, but do not need them + simultaneously. Use this functionality with caution. + @param[out] alias The vector of aliased fields + @param[in] v The vector of fields to alias + @param[in] param Parameters for the alias field + */ + void create_alias(cvector_ref &alias, cvector_ref &v, + const ColorSpinorParam &param = ColorSpinorParam()); + + /** + @brief Create a vector of fields that aliases another vector of + fields' storage. The alias field can use a different precision + than this field, though it cannot be greater. This functionality + is useful for the case where we have multiple temporaries in + different precisions, but do not need them simultaneously. This + variant is used with std::vector as opposed to vector_ref, and + allows for correct resizing. Use this functionality with + caution.
+ @param[out] alias The vector of aliased fields + @param[in] v The vector of fields to alias + @param[in] param Parameters for the alias field + */ + void create_alias(std::vector &alias, cvector_ref &v, + const ColorSpinorParam &param = ColorSpinorParam()); + void copyGenericColorSpinor(ColorSpinorField &dst, const ColorSpinorField &src, QudaFieldLocation location, void *Dst = nullptr, const void *Src = nullptr); diff --git a/lib/color_spinor_util.in.cu b/lib/color_spinor_util.in.cu index 0b9355d4d1..ef370eda18 100644 --- a/lib/color_spinor_util.in.cu +++ b/lib/color_spinor_util.in.cu @@ -423,4 +423,18 @@ namespace quda { resize(v, new_size, param); } + void create_alias(cvector_ref &alias, cvector_ref &v, + const ColorSpinorParam &param) + { + if (alias.size() != v.size()) errorQuda("sets differ in size (%lu != %lu)", alias.size(), v.size()); + for (auto i = 0u; i < v.size(); i++) alias[i] = const_cast(v[i]).create_alias(param); + } + + void create_alias(std::vector &alias, cvector_ref &v, + const ColorSpinorParam &param) + { + alias.resize(v.size()); + create_alias(cvector_ref(alias), v, param); + } + } // namespace quda diff --git a/lib/dirac_coarse.cpp b/lib/dirac_coarse.cpp index b8c646e84b..0136b29ac3 100644 --- a/lib/dirac_coarse.cpp +++ b/lib/dirac_coarse.cpp @@ -451,10 +451,8 @@ namespace quda { errorQuda("Preconditioned solution requires a preconditioned solve_type"); } - for (auto i = 0u; i < b.size(); i++) { - src[i] = const_cast(b[i]).create_alias(); - sol[i] = x[i].create_alias(); - } + create_alias(src, b); + create_alias(sol, x); } void DiracCoarse::reconstruct(cvector_ref &, cvector_ref &, @@ -579,19 +577,15 @@ namespace quda { { // we desire solution to preconditioned system if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) { - for (auto i = 0u; i < b.size(); i++) { - src[i] = const_cast(b[i]).create_alias(); - sol[i] = x[i].create_alias(); - } + create_alias(src, b); + create_alias(sol, x); return; } - auto tmp =
getFieldTmp(x.Even()); - for (auto i = 0u; i < b.size(); i++) { - src[i] = x[i][other_parity].create_alias(); - sol[i] = x[i][this_parity].create_alias(); - } + create_alias(src, x(other_parity)); + create_alias(sol, x(this_parity)); + auto tmp = getFieldTmp(x.Even()); // we desire solution to full system if (symmetric) { // src = A_ee^-1 (b_e - D_eo A_oo^-1 b_o) diff --git a/lib/dirac_wilson.cpp b/lib/dirac_wilson.cpp index 3c64ff2222..a6a5264b8e 100644 --- a/lib/dirac_wilson.cpp +++ b/lib/dirac_wilson.cpp @@ -73,10 +73,8 @@ namespace quda { errorQuda("Preconditioned solution requires a preconditioned solve_type"); } - for (auto i = 0u; i < b.size(); i++) { - src[i] = const_cast(b[i]).create_alias(); - sol[i] = x[i].create_alias(); - } + create_alias(src, b); + create_alias(sol, x); } void DiracWilson::reconstruct(cvector_ref &, cvector_ref &, @@ -138,22 +136,18 @@ namespace quda { cvector_ref &x, cvector_ref &b, const QudaSolutionType solType) const { - if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) { - for (auto i = 0u; i < b.size(); i++) { - src[i] = const_cast(b[i]).create_alias(); - sol[i] = x[i].create_alias(); - } + create_alias(src, b); + create_alias(sol, x); return; } // we desire solution to full system // src = b_e + k D_eo b_o DslashXpay(x(other_parity), b(other_parity), this_parity, b(this_parity), kappa); - for (auto i = 0u; i < b.size(); i++) { - src[i] = x[i][other_parity].create_alias(); - sol[i] = x[i][this_parity].create_alias(); - } + + create_alias(src, x(other_parity)); + create_alias(sol, x(this_parity)); } void DiracWilsonPC::reconstruct(cvector_ref &x, cvector_ref &b, From ee6fd265c6dfdc477e0313da4e5584aca2f35f02 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Fri, 28 Jun 2024 22:01:51 -0700 Subject: [PATCH 010/103] Add some more scalar wrappers: this facilitates us making the vector -> scalar cast operator explicit instead of implicit --- include/blas_quda.h | 83 
+++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 77 insertions(+), 6 deletions(-) diff --git a/include/blas_quda.h b/include/blas_quda.h index b7da601e38..889347e67f 100644 --- a/include/blas_quda.h +++ b/include/blas_quda.h @@ -293,7 +293,7 @@ namespace quda { inline array max_deviation(const ColorSpinorField &x, const ColorSpinorField &y) { - return max_deviation(cvector_ref(x), cvector_ref(y)); + return max_deviation(cvector_ref(x), cvector_ref(y))[0]; } /** @@ -302,13 +302,15 @@ namespace quda { */ cvector norm1(cvector_ref &x); + inline double norm1(const ColorSpinorField &x) { return norm1(cvector_ref {x})[0]; } + /** @brief Compute the L2 norm (||x||^2) of a field @param[in] x The field we are reducing */ cvector norm2(cvector_ref &x); - inline double norm2(const ColorSpinorField &x) { return norm2(cvector_ref {x}); } + inline double norm2(const ColorSpinorField &x) { return norm2(cvector_ref {x})[0]; } /** @brief Compute y += a * x and then (x, y) @@ -319,6 +321,11 @@ namespace quda { cvector axpyReDot(cvector &a, cvector_ref &x, cvector_ref &y); + inline double axpyReDot(double a, const ColorSpinorField &x, ColorSpinorField &y) + { + return axpyReDot(cvector(a), x, y)[0]; + } + /** @brief Compute the real-valued inner product (x, y) @param[in] x input vector @@ -328,7 +335,7 @@ namespace quda { inline double reDotProduct(const ColorSpinorField &x, const ColorSpinorField &y) { - return reDotProduct(cvector_ref {x}, cvector_ref {y}); + return reDotProduct(cvector_ref {x}, cvector_ref {y})[0]; } /** @@ -342,6 +349,12 @@ namespace quda { cvector axpbyzNorm(cvector &a, cvector_ref &x, cvector &b, cvector_ref &y, cvector_ref &z); + inline double axpbyzNorm(double a, const ColorSpinorField &x, double b, const ColorSpinorField &y, + ColorSpinorField &z) + { + return axpbyzNorm(cvector(a), x, cvector(b), y, z)[0]; + } + /** @brief Compute y += a * x and then ||y||^2 @param[in] a scalar multiplier @@ -354,6 +367,11 @@ namespace quda { return 
axpbyzNorm(a, x, 1.0, y, y); } + inline double axpyNorm(double a, const ColorSpinorField &x, ColorSpinorField &y) + { + return axpyNorm(a, cvector_ref {x}, cvector_ref {y})[0]; + } + /** @brief Compute the complex-valued inner product (x, y) @param[in] x input vector @@ -363,7 +381,7 @@ namespace quda { inline Complex cDotProduct(const ColorSpinorField &x, const ColorSpinorField &y) { - return cDotProduct(cvector_ref {x}, cvector_ref {y}); + return cDotProduct(cvector_ref {x}, cvector_ref {y})[0]; } /** @@ -373,6 +391,11 @@ namespace quda { */ cvector cDotProductNormAB(cvector_ref &x, cvector_ref &y); + inline double4 cDotProductNormAB(const ColorSpinorField &x, const ColorSpinorField &y) + { + return cDotProductNormAB(cvector_ref {x}, cvector_ref {y})[0]; + } + /** @brief Return complex-valued inner product (x,y) and ||x||^2 @param[in] x input vector @@ -387,6 +410,11 @@ namespace quda { return a; } + inline double3 cDotProductNormA(const ColorSpinorField &x, const ColorSpinorField &y) + { + return cDotProductNormA(cvector_ref {x}, cvector_ref {y})[0]; + } + /** @brief Return complex-valued inner product (x,y) and ||y||^2 @param[in] x input vector @@ -401,6 +429,11 @@ namespace quda { return a; } + inline double3 cDotProductNormB(const ColorSpinorField &x, const ColorSpinorField &y) + { + return cDotProductNormB(cvector_ref {x}, cvector_ref {y})[0]; + } + /** @brief Apply the operation z += a * x + b * y, y -= b * w, compute complex-valued inner product (u, y) and ||y||^2 @@ -418,6 +451,13 @@ namespace quda { cvector_ref &w, cvector_ref &u); + inline double3 caxpbypzYmbwcDotProductUYNormY(const Complex &a, const ColorSpinorField &x, const Complex &b, + ColorSpinorField &y, ColorSpinorField &z, const ColorSpinorField &w, + const ColorSpinorField &u) + { + return caxpbypzYmbwcDotProductUYNormY(cvector(a), x, cvector(b), y, z, w, u)[0]; + } + /** @brief Compute y = a * x + b * y and then ||y||^2 @param[in] a scalar multiplier @@ -440,6 +480,11 @@ namespace quda { 
return caxpbyNorm(a, x, 1.0, y); } + inline double caxpyNorm(const Complex &a, const ColorSpinorField &x, ColorSpinorField &y) + { + return caxpyNorm(a, cvector_ref {x}, cvector_ref {y})[0]; + } + /** @brief Compute y -= x and then ||y||^2 @param[in] x input vector @@ -450,6 +495,11 @@ namespace quda { return caxpbyNorm(1.0, x, -1.0, y); } + inline double xmyNorm(const ColorSpinorField &x, ColorSpinorField &y) + { + return xmyNorm(cvector_ref {x}, cvector_ref {y})[0]; + } + /** @brief Compute z = a * b * x + y, x = a * x, and then ||z||^2 @param[in] a scalar multiplier @@ -461,6 +511,12 @@ namespace quda { cvector cabxpyzAxNorm(cvector &a, cvector &b, cvector_ref &x, cvector_ref &y, cvector_ref &z); + inline double cabxpyzAxNorm(double a, const Complex &b, ColorSpinorField &x, const ColorSpinorField &y, + ColorSpinorField &z) + { + return cabxpyzAxNorm(cvector(a), cvector(b), x, y, z)[0]; + } + /** @brief Compute y += a * x and the resulting complex-valued inner product (z, y) @param[in] a scalar multiplier @@ -471,6 +527,11 @@ namespace quda { cvector caxpyDotzy(cvector &a, cvector_ref &x, cvector_ref &y, cvector_ref &z); + inline Complex caxpyDotzy(const Complex &a, const ColorSpinorField &x, ColorSpinorField &y, const ColorSpinorField &z) + { + return caxpyDotzy(cvector(a), x, y, z)[0]; + } + /** @brief Compute y += a * x and then compute ||y||^2 and real-valued inner product (y_out, y_out-y_in) @@ -481,6 +542,11 @@ namespace quda { cvector axpyCGNorm(cvector &a, cvector_ref &x, cvector_ref &y); + inline double2 axpyCGNorm(double a, const ColorSpinorField &x, ColorSpinorField &y) + { + return axpyCGNorm(cvector(a), x, y)[0]; + } + /** @brief Computes ||x||^2, ||r||^2 and the MILC/FNAL heavy quark residual norm @@ -492,7 +558,7 @@ namespace quda { inline double3 HeavyQuarkResidualNorm(const ColorSpinorField &x, const ColorSpinorField &r) { - return HeavyQuarkResidualNorm(cvector_ref(x), cvector_ref(r)); + return HeavyQuarkResidualNorm(cvector_ref(x), 
cvector_ref(r))[0]; } /** @@ -510,7 +576,7 @@ namespace quda { const ColorSpinorField &r) { return xpyHeavyQuarkResidualNorm(cvector_ref(x), cvector_ref(y), - cvector_ref(r)); + cvector_ref(r))[0]; } /** @@ -522,6 +588,11 @@ namespace quda { cvector tripleCGReduction(cvector_ref &x, cvector_ref &y, cvector_ref &z); + inline double3 tripleCGReduction(const ColorSpinorField &x, const ColorSpinorField &y, const ColorSpinorField &z) + { + return tripleCGReduction(cvector_ref(x), y, z)[0]; + } + /** @brief Computes ||x||^2, ||y||^2, the real-valued inner product (y, z), and ||z||^2 @param[in] x input vector From ac23c73cf14684c34a81cceb7aacfd876947a64f Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Fri, 28 Jun 2024 22:02:09 -0700 Subject: [PATCH 011/103] Suppress annoying warning with Eigen --- include/eigen_helper.h | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/include/eigen_helper.h b/include/eigen_helper.h index f4d91d74b5..b3eb57baf0 100644 --- a/include/eigen_helper.h +++ b/include/eigen_helper.h @@ -10,8 +10,15 @@ #endif #include + +// hide annoying warning +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wmaybe-uninitialized" + #include #include #include +#pragma GCC diagnostic pop + using namespace Eigen; From 2b0763b0c230587eb23221dd3f86d5ea71917aec Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Fri, 28 Jun 2024 22:04:13 -0700 Subject: [PATCH 012/103] Add default copy/move constructors/assignment operator for XUpdateBatch - will be needed for batched CG --- include/invert_x_update.h | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/include/invert_x_update.h b/include/invert_x_update.h index 6726275bde..64ec7caa33 100644 --- a/include/invert_x_update.h +++ b/include/invert_x_update.h @@ -43,6 +43,11 @@ namespace quda } } + XUpdateBatch(const XUpdateBatch &other) = default; + XUpdateBatch(XUpdateBatch &&other) = default; + XUpdateBatch &operator=(const XUpdateBatch &other) = default; + XUpdateBatch &operator=(XUpdateBatch
&&other) = default; + /** @brief use the vectors currently stored and add to the given output field @param x the output field to add to From 70a94df2133fc19f4d98aceac48358994986f46f Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Fri, 28 Jun 2024 22:12:24 -0700 Subject: [PATCH 013/103] Add some useful overloads to vector class to facilitate writing batching solvers --- include/reference_wrapper_helper.h | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/include/reference_wrapper_helper.h b/include/reference_wrapper_helper.h index a9eccf3648..9890503f1b 100644 --- a/include/reference_wrapper_helper.h +++ b/include/reference_wrapper_helper.h @@ -535,8 +535,38 @@ namespace quda if (std::vector::size() != 1) errorQuda("Cast to scalar failed since size = %lu", std::vector::size()); return std::vector::operator[](0); } + + bool operator<(const vector &v) const + { + for (auto i = 0u; i < v.size(); i++) + if (this->operator[](i) >= v[i]) return false; + return true; + } + + bool operator>(const vector &v) const + { + for (auto i = 0u; i < v.size(); i++) + if (this->operator[](i) <= v[i]) return false; + return true; + } + + vector operator-() const + { + vector negative(*this); + for (auto &v : negative) v = -v; + return negative; + } + + vector operator*(const T &u) const + { + vector multiplied(*this); + for (auto &v : multiplied) v *= u; + return multiplied; + } }; + template vector operator*(const T &a, const vector &b) { return b * a; } + template using cvector = const vector; } // namespace quda From f792a33b5d96e844da911a3fe1931ddf42adf923 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Fri, 28 Jun 2024 22:39:30 -0700 Subject: [PATCH 014/103] Add explicit casting to double in anticipation of making the cast operator explicit --- lib/interface_quda.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/interface_quda.cpp b/lib/interface_quda.cpp index b7c046f955..370e99f5ae 100644 ---
a/lib/interface_quda.cpp +++ b/lib/interface_quda.cpp @@ -4046,8 +4046,8 @@ void invertMultiShiftQuda(void **hp_x, void *hp_b, QudaInvertParam *param) cg(x[i], b); } - solverParam.true_res_offset[i] = solverParam.true_res; - solverParam.true_res_hq_offset[i] = solverParam.true_res_hq; + solverParam.true_res_offset[i] = static_cast(solverParam.true_res); + solverParam.true_res_hq_offset[i] = static_cast(solverParam.true_res_hq); solverParam.updateInvertParam(*param,i); if (param->dslash_type == QUDA_ASQTAD_DSLASH || From 57ba15ef38abb016b791c0362bb876e20dda058c Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Sun, 30 Jun 2024 23:08:49 -0700 Subject: [PATCH 015/103] First pass at enabling MRHS for CG, MR and SD solvers. To better find issues at compile time, the scalar -> vector cast operator is now explicit. Apply necessary changes to ensure that stuff doesn't break with this change --- include/blas_quda.h | 30 +- include/invert_quda.h | 118 ++++---- include/reference_wrapper_helper.h | 3 +- lib/inv_bicgstab_quda.cpp | 2 +- lib/inv_ca_cg.cpp | 6 +- lib/inv_ca_gcr.cpp | 2 +- lib/inv_cg3_quda.cpp | 8 +- lib/inv_cg_quda.cpp | 430 +++++++++++++++++------------ lib/inv_gcr_quda.cpp | 2 +- lib/inv_mr_quda.cpp | 72 ++--- lib/inv_sd_quda.cpp | 56 ++-- lib/solver.cpp | 158 ++++++----- tests/dslash_test_utils.h | 6 +- 13 files changed, 523 insertions(+), 370 deletions(-) diff --git a/include/blas_quda.h b/include/blas_quda.h index 889347e67f..907f23945b 100644 --- a/include/blas_quda.h +++ b/include/blas_quda.h @@ -323,7 +323,7 @@ namespace quda { inline double axpyReDot(double a, const ColorSpinorField &x, ColorSpinorField &y) { - return axpyReDot(cvector(a), x, y)[0]; + return axpyReDot(cvector(a), cvector_ref(x), y)[0]; } /** @@ -352,7 +352,7 @@ namespace quda { inline double axpbyzNorm(double a, const ColorSpinorField &x, double b, const ColorSpinorField &y, ColorSpinorField &z) { - return axpbyzNorm(cvector(a), x, cvector(b), y, z)[0]; + return axpbyzNorm(cvector(a), 
cvector_ref(x), cvector(b), y, z)[0]; } /** @@ -455,7 +455,8 @@ namespace quda { ColorSpinorField &y, ColorSpinorField &z, const ColorSpinorField &w, const ColorSpinorField &u) { - return caxpbypzYmbwcDotProductUYNormY(cvector(a), x, cvector(b), y, z, w, u)[0]; + return caxpbypzYmbwcDotProductUYNormY(cvector(a), cvector_ref(x), b, y, z, w, + u)[0]; } /** @@ -514,7 +515,7 @@ namespace quda { inline double cabxpyzAxNorm(double a, const Complex &b, ColorSpinorField &x, const ColorSpinorField &y, ColorSpinorField &z) { - return cabxpyzAxNorm(cvector(a), cvector(b), x, y, z)[0]; + return cabxpyzAxNorm(cvector(a), cvector(b), cvector_ref(x), y, z)[0]; } /** @@ -529,7 +530,7 @@ namespace quda { inline Complex caxpyDotzy(const Complex &a, const ColorSpinorField &x, ColorSpinorField &y, const ColorSpinorField &z) { - return caxpyDotzy(cvector(a), x, y, z)[0]; + return caxpyDotzy(cvector(a), cvector_ref(x), y, z)[0]; } /** @@ -544,7 +545,7 @@ namespace quda { inline double2 axpyCGNorm(double a, const ColorSpinorField &x, ColorSpinorField &y) { - return axpyCGNorm(cvector(a), x, y)[0]; + return axpyCGNorm(cvector(a), cvector_ref(x), y)[0]; } /** @@ -602,6 +603,11 @@ namespace quda { cvector quadrupleCGReduction(cvector_ref &x, cvector_ref &y, cvector_ref &z); + inline double4 quadrupleCGReduction(const ColorSpinorField &x, const ColorSpinorField &y, const ColorSpinorField &z) + { + return quadrupleCGReduction(cvector_ref(x), y, z)[0]; + } + /** @brief Computes z = x, w = y, x += a * y, y -= a * v and ||y||^2 @param[in] a scalar multiplier @@ -615,6 +621,12 @@ namespace quda { cvector_ref &y, cvector_ref &z, cvector_ref &w, cvector_ref &v); + inline double quadrupleCG3InitNorm(double a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, + ColorSpinorField &w, const ColorSpinorField &v) + { + return quadrupleCG3InitNorm(cvector(a), cvector_ref(x), y, z, w, v)[0]; + } + /** @brief Computes x = b * (x + a * y) + ( 1 - b) * z, y = b * (y + a * v) + (1 - b) * w, z = 
x_in, w = y_in, and @@ -631,6 +643,12 @@ namespace quda { cvector_ref &y, cvector_ref &z, cvector_ref &w, cvector_ref &v); + inline double quadrupleCG3UpdateNorm(double a, double b, ColorSpinorField &x, ColorSpinorField &y, + ColorSpinorField &z, ColorSpinorField &w, const ColorSpinorField &v) + { + return quadrupleCG3UpdateNorm(cvector(a), b, cvector_ref(x), y, z, w, v)[0]; + } + namespace block { diff --git a/include/invert_quda.h b/include/invert_quda.h index 19d02a2635..4520712644 100644 --- a/include/invert_quda.h +++ b/include/invert_quda.h @@ -113,7 +113,7 @@ namespace quda { int max_hq_res_increase = 0; /**< This parameter determines how many total heavy-quark residual - restarts we tolerate before terminating the solver, i.e., how long + restarts we tolerate before terminating the solver, i.e., how long do we want to keep trying to converge */ int max_hq_res_restart_total = 0; @@ -139,10 +139,10 @@ namespace quda { bool sloppy_converge = false; /**< Actual L2 residual norm achieved in solver */ - double true_res = 0.0; + vector true_res; /**< Actual heavy quark residual norm achieved in solver */ - double true_res_hq = 0.0; + vector true_res_hq; /**< Maximum number of iterations in the linear solver */ int maxiter = 0; @@ -374,8 +374,8 @@ namespace quda { @param param the QudaInvertParam to be updated */ void updateInvertParam(QudaInvertParam &param, int offset=-1) { - param.true_res = true_res; - param.true_res_hq = true_res_hq; + param.true_res = static_cast(true_res); + param.true_res_hq = static_cast(true_res_hq); param.iter += iter; if (offset >= 0) { param.true_res_offset[offset] = true_res_offset[offset]; @@ -439,11 +439,10 @@ namespace quda { /** @return Return the residual vector from the prior solve */ - virtual ColorSpinorField &get_residual() + virtual cvector_ref get_residual() { errorQuda("Not implemented"); - static ColorSpinorField dummy; - return dummy; + return cvector_ref(); } /** @@ -490,7 +489,7 @@ namespace quda { @param[in] x Solution
vector @param[in] b Source vector */ - void create(ColorSpinorField &x, const ColorSpinorField &b); + void create(cvector_ref &x, cvector_ref &b); /** @brief Solver factory @@ -544,7 +543,12 @@ namespace quda { @param[in] residual_type The type of residual we want to solve for @return L2 stopping condition */ - static double stopping(double tol, double b2, QudaResidualType residual_type); + static vector stopping(double tol, cvector &b2, QudaResidualType residual_type); + + static inline double stopping(double tol, double b2, QudaResidualType residual_type) + { + return stopping(tol, cvector(b2), residual_type)[0]; + } /** @briefTest for solver convergence @@ -554,7 +558,7 @@ namespace quda { @param[in] hq_tol Solver heavy-quark tolerance @return Whether converged */ - bool convergence(double r2, double hq2, double r2_tol, double hq_tol); + bool convergence(cvector &r2, cvector &hq2, cvector &r2_tol, cvector &hq_tol); /** @brief Test for HQ solver convergence -- ignore L2 residual @@ -564,7 +568,7 @@ namespace quda { @param[in[ hq_tol Solver heavy-quark tolerance @return Whether converged */ - bool convergenceHQ(double r2, double hq2, double r2_tol, double hq_tol); + bool convergenceHQ(cvector &hq2, cvector &hq_tol); /** @brief Test for L2 solver convergence -- ignore HQ residual @@ -573,7 +577,7 @@ namespace quda { @param[in] r2_tol Solver L2 tolerance @param[in] hq_tol Solver heavy-quark tolerance */ - bool convergenceL2(double r2, double hq2, double r2_tol, double hq_tol); + bool convergenceL2(cvector &r2, cvector &r2_tol); /** @brief Prints out the running statistics of the solver @@ -583,7 +587,7 @@ namespace quda { @param[in] r2 L2 norm squared of the residual @param[in] hq2 Heavy quark residual */ - void PrintStats(const char *name, int k, double r2, double b2, double hq2); + void PrintStats(const char *name, int k, cvector &r2, cvector &b2, cvector &hq2 = {}); /** @brief Prints out the summary of the solver convergence @@ -596,7 +600,8 @@ namespace quda { 
@param[in] r2_tol Solver L2 tolerance @param[in] hq_tol Solver heavy-quark tolerance */ - void PrintSummary(const char *name, int k, double r2, double b2, double r2_tol, double hq_tol); + void PrintSummary(const char *name, int k, cvector &r2, cvector &b2, cvector &r2_tol, + cvector &hq_tol = {}); /** @brief Returns the epsilon tolerance for a given precision, by default returns @@ -702,13 +707,13 @@ namespace quda { class CG : public Solver { private: - ColorSpinorField y; - ColorSpinorField r; - ColorSpinorField rnew; - ColorSpinorField p; - ColorSpinorField Ap; - ColorSpinorField rSloppy; - ColorSpinorField xSloppy; + std::vector y; + std::vector r; + std::vector rnew; + std::vector p; + std::vector Ap; + std::vector r_sloppy; + std::vector x_sloppy; bool init = false; /** @@ -716,7 +721,7 @@ namespace quda { @param[in] x Solution vector @param[in] b Source vector */ - void create(ColorSpinorField &x, const ColorSpinorField &b); + void create(cvector_ref &x, cvector_ref &b); public: CG(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, const DiracMatrix &matEig, @@ -730,8 +735,9 @@ namespace quda { */ void operator()(cvector_ref &out, cvector_ref &in) override { - for (auto i = 0u; i < in.size(); i++) (*this)(out[i], in[i], ColorSpinorField(), 0.0); - }; + std::vector tmp(in.size()); + operator()(out, in, tmp, vector(in.size(), 0.0)); + } /** * @brief Solve re-using an initial Krylov space defined by an initial r2_old_init and search direction p_init. @@ -741,15 +747,15 @@ namespace quda { * @param p_init Initial-search direction. 
* @param r2_old_init [description] */ - void operator()(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &p_init, - double r2_old_init); + void operator()(cvector_ref &out, cvector_ref &in, + cvector_ref &p_init, cvector &r2_old_init); void blocksolve(ColorSpinorField &out, ColorSpinorField &in) override; /** @return Return the residual vector from the prior solve */ - ColorSpinorField &get_residual() override; + cvector_ref get_residual() override; virtual bool hermitian() const override { return true; } /** CG is only for Hermitian systems */ @@ -761,7 +767,7 @@ namespace quda { * @param out Solution-vector. * @param in Right-hand side. */ - void hqsolve(ColorSpinorField &out, const ColorSpinorField &in); + void hqsolve(cvector_ref &out, cvector_ref &in); }; class CGNE : public CG @@ -797,7 +803,7 @@ namespace quda { /** @return Return the residual vector from the prior solve */ - ColorSpinorField &get_residual() override; + cvector_ref get_residual() override; virtual bool hermitian() const final { return false; } /** CGNE is for any system */ @@ -836,7 +842,7 @@ namespace quda { /** @return Return the residual vector from the prior solve */ - ColorSpinorField &get_residual() override; + cvector_ref get_residual() override; virtual bool hermitian() const final { return false; } /** CGNR is for any system */ @@ -877,7 +883,7 @@ namespace quda { /** @return Return the residual vector from the prior solve */ - ColorSpinorField &get_residual() override; + cvector_ref get_residual() override; virtual bool hermitian() const override { return true; } /** CG is only for Hermitian systems */ @@ -915,7 +921,7 @@ namespace quda { /** @return Return the residual vector from the prior solve */ - ColorSpinorField &get_residual() override; + cvector_ref get_residual() override; virtual bool hermitian() const final { return false; } /** CG3NE is for any system */ @@ -952,7 +958,7 @@ namespace quda { /** @return Return the residual vector from the prior 
solve */ - ColorSpinorField &get_residual(); + cvector_ref get_residual(); virtual bool hermitian() const final { return false; } /** CG3NR is for any system */ @@ -1063,7 +1069,7 @@ namespace quda { /** @return Return the residual vector from the prior solve */ - ColorSpinorField &get_residual() override; + cvector_ref get_residual() override; virtual bool hermitian() const override { return false; } /** BiCGStab is for any linear system */ @@ -1246,7 +1252,7 @@ namespace quda { /** @return Return the residual vector from the prior solve */ - ColorSpinorField &get_residual() override; + cvector_ref get_residual() override; virtual bool hermitian() const override { return false; } /** GCR is for any linear system */ @@ -1256,10 +1262,10 @@ namespace quda { class MR : public Solver { private: - ColorSpinorField r; - ColorSpinorField r_sloppy; - ColorSpinorField Ar; - ColorSpinorField x_sloppy; + std::vector r; + std::vector r_sloppy; + std::vector Ar; + std::vector x_sloppy; bool init = false; /** @@ -1267,22 +1273,23 @@ namespace quda { @param[in] x Solution vector @param[in] b Source vector */ - void create(ColorSpinorField &x, const ColorSpinorField &b); + void create(cvector_ref &x, cvector_ref &b); public: MR(const DiracMatrix &mat, const DiracMatrix &matSloppy, SolverParam &param); - void operator()(cvector_ref &out, cvector_ref &in) override + void operator()(cvector_ref &out, cvector_ref &in) override; +#if 0 { for (auto i = 0u; i < in.size(); i++) operator()(out[i], in[i]); } void operator()(ColorSpinorField &out, const ColorSpinorField &in); - +#endif /** @return Return the residual vector from the prior solve */ - ColorSpinorField &get_residual() override; + cvector_ref get_residual() override; virtual bool hermitian() const override { return false; } /** MR is for any linear system */ @@ -1359,7 +1366,7 @@ namespace quda { /** @return Return the residual vector from the prior solve */ - ColorSpinorField &get_residual() override; + cvector_ref get_residual()
override; virtual bool hermitian() const override { return true; } /** CG is only for Hermitian systems */ @@ -1398,7 +1405,7 @@ namespace quda { /** @return Return the residual vector from the prior solve */ - ColorSpinorField &get_residual() override; + cvector_ref get_residual() override; virtual bool hermitian() const final { return false; } /** CA-CGNE is for any linear system */ @@ -1437,7 +1444,7 @@ namespace quda { /** @return Return the residual vector from the prior solve */ - ColorSpinorField &get_residual() override; + cvector_ref get_residual() override; virtual bool hermitian() const final { return false; } /** CA-CGNR is for any linear system */ @@ -1498,7 +1505,7 @@ namespace quda { /** @return Return the residual vector from the prior solve */ - ColorSpinorField &get_residual() override; + cvector_ref get_residual() override; virtual bool hermitian() const override { return false; } /** GCR is for any linear system */ @@ -1508,8 +1515,8 @@ namespace quda { // Steepest descent solver used as a preconditioner class SD : public Solver { private: - ColorSpinorField Ar; - ColorSpinorField r; + std::vector Ar; + std::vector r; bool init = false; /** @@ -1517,22 +1524,17 @@ namespace quda { @param[in] x Solution vector @param[in] b Source vector */ - void create(ColorSpinorField &x, const ColorSpinorField &b); + void create(cvector_ref &x, cvector_ref &b); public: SD(const DiracMatrix &mat, SolverParam &param); - void operator()(cvector_ref &out, cvector_ref &in) override - { - for (auto i = 0u; i < in.size(); i++) operator()(out[i], in[i]); - } - - void operator()(ColorSpinorField &out, const ColorSpinorField &in); + void operator()(cvector_ref &out, cvector_ref &in) override; /** @return Return the residual from the prior solve */ - ColorSpinorField &get_residual() override; + cvector_ref get_residual() override; virtual bool hermitian() const override { return false; } /** SD is for any linear system */ diff --git a/include/reference_wrapper_helper.h
b/include/reference_wrapper_helper.h index 9890503f1b..50dc066701 100644 --- a/include/reference_wrapper_helper.h +++ b/include/reference_wrapper_helper.h @@ -530,7 +530,7 @@ namespace quda /** @brief Cast to scalar. Only works if the vector size is 1. */ - operator T() const + explicit operator T() const { if (std::vector::size() != 1) errorQuda("Cast to scalar failed since size = %lu", std::vector::size()); return std::vector::operator[](0); @@ -563,6 +563,7 @@ namespace quda for (auto &v : multiplied) v *= u; return multiplied; } + }; template vector operator*(const T &a, const vector &b) { return b * a; } diff --git a/lib/inv_bicgstab_quda.cpp b/lib/inv_bicgstab_quda.cpp index b1468c1b7d..1c2e8187b5 100644 --- a/lib/inv_bicgstab_quda.cpp +++ b/lib/inv_bicgstab_quda.cpp @@ -40,7 +40,7 @@ namespace quda { } // init } - ColorSpinorField &BiCGstab::get_residual() + cvector_ref BiCGstab::get_residual() { if (!init) errorQuda("No residual vector present"); if (!param.return_residual) errorQuda("SolverParam::return_residual not enabled"); diff --git a/lib/inv_ca_cg.cpp b/lib/inv_ca_cg.cpp index b7fdd5eb9a..66f3800afb 100644 --- a/lib/inv_ca_cg.cpp +++ b/lib/inv_ca_cg.cpp @@ -50,7 +50,7 @@ namespace quda } } - ColorSpinorField &CACGNE::get_residual() + cvector_ref CACGNE::get_residual() { if (!init) errorQuda("No residual vector present"); if (!param.return_residual) errorQuda("SolverParam::return_residual not enabled"); @@ -130,7 +130,7 @@ namespace quda } } - ColorSpinorField &CACGNR::get_residual() + cvector_ref CACGNR::get_residual() { if (!init) errorQuda("No residual vector present"); if (!param.return_residual) errorQuda("SolverParam::return_residual not enabled"); @@ -373,7 +373,7 @@ namespace quda return updateR; } - ColorSpinorField &CACG::get_residual() + cvector_ref CACG::get_residual() { if (!init) errorQuda("No residual vector present"); if (!param.return_residual) errorQuda("SolverParam::return_residual not enabled"); diff --git a/lib/inv_ca_gcr.cpp 
b/lib/inv_ca_gcr.cpp index 7af84197c7..abd68aad3d 100644 --- a/lib/inv_ca_gcr.cpp +++ b/lib/inv_ca_gcr.cpp @@ -113,7 +113,7 @@ namespace quda } } - ColorSpinorField &CAGCR::get_residual() + cvector_ref CAGCR::get_residual() { if (!init) errorQuda("No residual vector present"); if (!param.return_residual) errorQuda("SolverParam::return_residual not enabled"); diff --git a/lib/inv_cg3_quda.cpp b/lib/inv_cg3_quda.cpp index 9a33040af8..eb2564ba7d 100644 --- a/lib/inv_cg3_quda.cpp +++ b/lib/inv_cg3_quda.cpp @@ -36,7 +36,7 @@ namespace quda { } } - ColorSpinorField &CG3NE::get_residual() + cvector_ref CG3NE::get_residual() { if (!init) errorQuda("No residual vector present"); if (!param.return_residual) errorQuda("SolverParam::return_residual not enabled"); @@ -115,7 +115,7 @@ namespace quda { } } - ColorSpinorField &CG3NR::get_residual() + cvector_ref CG3NR::get_residual() { if (!init) errorQuda("No residual vector present"); if (!param.return_residual) errorQuda("SolverParam::return_residual not enabled"); @@ -192,7 +192,7 @@ namespace quda { /** @return Return the residual vector from the prior solve */ - ColorSpinorField &CG3::get_residual() + cvector_ref CG3::get_residual() { if (!init) errorQuda("No residual vector present"); return r; @@ -320,7 +320,7 @@ namespace quda { if (convergence(r2, heavy_quark_res, stop, param.tol_hq) && param.delta >= param.tol) update = true; // For heavy-quark inversion force a reliable update if we continue after - if ( use_heavy_quark_res and L2breakdown and convergenceHQ(r2, heavy_quark_res, stop, param.tol_hq) and param.delta >= param.tol ) { + if ( use_heavy_quark_res and L2breakdown and convergenceHQ(heavy_quark_res, param.tol_hq) and param.delta >= param.tol ) { update = true; } diff --git a/lib/inv_cg_quda.cpp b/lib/inv_cg_quda.cpp index dcb1106259..31911b16cd 100644 --- a/lib/inv_cg_quda.cpp +++ b/lib/inv_cg_quda.cpp @@ -27,30 +27,37 @@ namespace quda { CG::~CG() { destroyDeflationSpace(); } - void CG::create(ColorSpinorField 
&x, const ColorSpinorField &b) + void CG::create(cvector_ref &x, cvector_ref &b) { Solver::create(x, b); if (!init) { getProfile().TPSTART(QUDA_PROFILE_INIT); - ColorSpinorParam csParam(x); - csParam.create = QUDA_NULL_FIELD_CREATE; - r = ColorSpinorField(csParam); - y = ColorSpinorField(csParam); + resize(r, b.size(), QUDA_NULL_FIELD_CREATE, b[0]); + resize(y, b.size(), QUDA_NULL_FIELD_CREATE, b[0]); // sloppy fields + ColorSpinorParam csParam(x[0]); + csParam.create = QUDA_NULL_FIELD_CREATE; csParam.setPrecision(param.precision_sloppy); - p = ColorSpinorField(csParam); - Ap = ColorSpinorField(csParam); + resize(p, b.size(), csParam); + resize(Ap, b.size(), csParam); - rSloppy = (r.Precision() != param.precision_sloppy) ? ColorSpinorField(csParam) : r.create_alias(); + if (param.precision != param.precision_sloppy) { + resize(r_sloppy, b.size(), csParam); + } else { + create_alias(r_sloppy, r); + } param.use_sloppy_partial_accumulator = false; // hard-code precise accumulation - xSloppy = (param.use_sloppy_partial_accumulator == true) ? 
ColorSpinorField(csParam) : x.create_alias(); + if (param.use_sloppy_partial_accumulator) resize(x_sloppy, b.size(), csParam); init = true; getProfile().TPSTOP(QUDA_PROFILE_INIT); } + + // need to reset x_sloppy every solve + if (!param.use_sloppy_partial_accumulator) create_alias(x_sloppy, x); } CGNE::CGNE(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, @@ -76,7 +83,7 @@ namespace quda { } } - ColorSpinorField &CGNE::get_residual() + cvector_ref CGNE::get_residual() { if (!init) errorQuda("No residual vector present"); if (!param.return_residual) errorQuda("SolverParam::return_residual not enabled"); @@ -156,7 +163,7 @@ namespace quda { } } - ColorSpinorField &CGNR::get_residual() + cvector_ref CGNR::get_residual() { if (!init) errorQuda("No residual vector present"); if (!param.return_residual) errorQuda("SolverParam::return_residual not enabled"); @@ -206,7 +213,8 @@ namespace quda { } } - void CG::operator()(ColorSpinorField &x, const ColorSpinorField &b, const ColorSpinorField &p_init, double r2_old_init) + void CG::operator()(cvector_ref &x, cvector_ref &b, + cvector_ref &p_init, cvector &r2_old_init) { if (param.is_preconditioner) commGlobalReductionPush(param.global_reduction); @@ -244,23 +252,29 @@ namespace quda { if (!param.is_preconditioner) getProfile().TPSTART(QUDA_PROFILE_INIT); - double b2 = blas::norm2(b); + vector b2 = blas::norm2(b); // Check to see that we're not trying to invert on a zero-field source - if (b2 == 0 && param.compute_null_vector == QUDA_COMPUTE_NULL_VECTOR_NO) { - if (!param.is_preconditioner) getProfile().TPSTOP(QUDA_PROFILE_INIT); - printfQuda("Warning: inverting on zero-field source\n"); - x = b; - param.true_res = 0.0; - param.true_res_hq = 0.0; - return; + if (param.compute_null_vector == QUDA_COMPUTE_NULL_VECTOR_NO) { + bool zero_src = true; + for (auto i = 0u; i < b.size(); i++) { + if (b2[i] == 0) { + warningQuda("inverting on zero-field source"); + x[i] = b[i]; + param.true_res[i] = 
0.0; + param.true_res_hq[i] = 0.0; + } else { + zero_src = false; + } + } + if (zero_src) return; } create(x, b); if (param.deflate) { // Construct the eigensolver and deflation space if requested. - constructDeflationSpace(b, matEig); + constructDeflationSpace(b[0], matEig); if (deflate_compute) { // compute the deflation space. if (!param.is_preconditioner) getProfile().TPSTOP(QUDA_PROFILE_INIT); @@ -277,27 +291,28 @@ namespace quda { const double u = precisionEpsilon(param.precision_sloppy); const double uhigh = precisionEpsilon(); // solver precision - double Anorm = 0; - double beta = 0; + double Anorm = 0.0; + vector beta(b.size(), 0.0); // for alternative reliable updates if (advanced_feature && alternative_reliable) { // estimate norm for reliable updates - mat(r, b); - Anorm = sqrt(blas::norm2(r)/b2); + mat(r[0], b[0]); + Anorm = sqrt(blas::norm2(r[0]) / b2[0]); } // compute initial residual - double r2 = 0.0; + vector r2(b2.size(), 0.0); if (advanced_feature && param.use_init_guess == QUDA_USE_INIT_GUESS_YES) { // Compute r = b - A * x mat(r, x); r2 = blas::xmyNorm(b, r); - if (b2 == 0) b2 = r2; + for (auto i = 0u; i < b.size(); i++) + if (b2[i] == 0) b2[i] = r2[i]; // y contains the original guess. blas::copy(y, x); } else { - if (&r != &b) blas::copy(r, b); + blas::copy(r, b); r2 = b2; blas::zero(y); } @@ -310,20 +325,25 @@ namespace quda { } blas::zero(x); - if (&x != &xSloppy) blas::zero(xSloppy); - blas::copy(rSloppy,r); + if (param.use_sloppy_partial_accumulator) blas::zero(x_sloppy); + blas::copy(r_sloppy, r); - ColorSpinorParam csParam(rSloppy); + ColorSpinorParam csParam(r_sloppy[0]); csParam.create = QUDA_NULL_FIELD_CREATE; - XUpdateBatch x_update_batch(Np, !p_init.empty() ? 
p_init : rSloppy, csParam); - - double r2_old = 0.0; - if (r2_old_init != 0.0 and !p_init.empty()) { - r2_old = r2_old_init; - Complex rp = blas::cDotProduct(rSloppy, x_update_batch.get_current_field()) / (r2); - blas::caxpy(-rp, rSloppy, x_update_batch.get_current_field()); - beta = r2 / r2_old; - blas::xpayz(rSloppy, beta, x_update_batch.get_current_field(), x_update_batch.get_current_field()); + std::vector x_update_batch(b.size()); + for (auto i = 0u; i < b.size(); i++) + x_update_batch[i] = XUpdateBatch(Np, !p_init[i].empty() ? p_init[i] : r_sloppy[i], csParam); + + vector r2_old(r2.size(), 0.0); + for (auto i = 0u; i < b.size(); i++) { + if (r2_old_init[i] != 0.0 and !p_init[i].empty()) { + // FIXME vectorize this + r2_old[i] = r2_old_init[i]; + Complex rp = blas::cDotProduct(r_sloppy[i], x_update_batch[i].get_current_field()) / (r2[i]); + blas::caxpy(-rp, r_sloppy[i], x_update_batch[i].get_current_field()); + beta[i] = r2[i] / r2_old[i]; + blas::xpayz(r_sloppy[i], beta[i], x_update_batch[i].get_current_field(), x_update_batch[i].get_current_field()); + } } if (!param.is_preconditioner) { @@ -331,10 +351,9 @@ namespace quda { getProfile().TPSTART(QUDA_PROFILE_PREAMBLE); } - double stop = stopping(param.tol, b2, param.residual_type); // stopping condition of solver + auto stop = stopping(param.tol, b2, param.residual_type); // stopping condition of solver - auto alpha = std::make_unique(Np); - double pAp; + vector pAp(b.size()); if (!param.is_preconditioner) { getProfile().TPSTOP(QUDA_PROFILE_PREAMBLE); @@ -345,7 +364,7 @@ namespace quda { PrintStats("CG", k, r2, b2, 0.0); - bool converged = convergenceL2(r2, 0.0, stop, 0.0); + bool converged = convergenceL2(r2, stop); ReliableUpdatesParams ru_params; @@ -359,30 +378,56 @@ namespace quda { ru_params.maxResIncreaseTotal = param.max_res_increase_total; ru_params.use_heavy_quark_res = false; // since we've removed HQ residual support - ReliableUpdates ru(ru_params, r2); + ReliableUpdates ru(ru_params, r2[0]); + + 
auto get_p = [](std::vector &x_update_batch, bool next = false) { + vector_ref p; + p.reserve(x_update_batch.size()); + for (auto &x : x_update_batch) p.push_back(next ? x.get_next_field() : x.get_current_field()); + return p; + }; + + auto get_alpha = [](std::vector &x_update_batch) { + vector alpha; + alpha.reserve(x_update_batch.size()); + for (auto &x : x_update_batch) alpha.push_back(x.get_current_alpha()); + return alpha; + }; while ( !converged && k < param.maxiter ) { - matSloppy(Ap, x_update_batch.get_current_field()); - double sigma; + auto p = get_p(x_update_batch); + auto p_next = get_p(x_update_batch, true); + matSloppy(Ap, p); + + vector sigma(b.size()); + ; bool breakdown = false; if (advanced_feature && param.pipeline) { - double Ap2; - if(alternative_reliable){ - double4 quadruple = blas::quadrupleCGReduction(rSloppy, Ap, x_update_batch.get_current_field()); - r2 = quadruple.x; - Ap2 = quadruple.y; - pAp = quadruple.z; - ru.update_ppnorm(quadruple.w); + vector Ap2(b.size()); + if (alternative_reliable) { + auto quadruple = blas::quadrupleCGReduction(r_sloppy, Ap, p); + for (auto i = 0u; i < b.size(); i++) { + r2[i] = quadruple[i].x; + Ap2[i] = quadruple[i].y; + pAp[i] = quadruple[i].z; + } + ru.update_ppnorm(quadruple[0].w); // using 0th system for RU } else { - double3 triplet = blas::tripleCGReduction(rSloppy, Ap, x_update_batch.get_current_field()); - r2 = triplet.x; Ap2 = triplet.y; pAp = triplet.z; + auto triplet = blas::tripleCGReduction(r_sloppy, Ap, p); + for (auto i = 0u; i < b.size(); i++) { + r2[i] = triplet[i].x; + Ap2[i] = triplet[i].y; + pAp[i] = triplet[i].z; + } } r2_old = r2; - x_update_batch.get_current_alpha() = r2 / pAp; - sigma = x_update_batch.get_current_alpha() * (x_update_batch.get_current_alpha() * Ap2 - pAp); - if (sigma < 0.0 || ru.steps_since_reliable == 0) { // sigma condition has broken down - r2 = blas::axpyNorm(-x_update_batch.get_current_alpha(), Ap, rSloppy); + for (auto i = 0u; i < b.size(); i++) { + 
x_update_batch[i].get_current_alpha() = r2[i] / pAp[i]; + sigma[i] = x_update_batch[i].get_current_alpha() * (x_update_batch[i].get_current_alpha() * Ap2[i] - pAp[i]); + } + if (sigma[0] < 0.0 || ru.steps_since_reliable == 0) { // sigma condition has broken down + r2 = blas::axpyNorm(-get_alpha(x_update_batch), Ap, r_sloppy); sigma = r2; breakdown = true; } @@ -393,69 +438,71 @@ namespace quda { // alternative reliable updates, if (advanced_feature && alternative_reliable) { - double3 pAppp = blas::cDotProductNormA(x_update_batch.get_current_field(), Ap); - pAp = pAppp.x; - ru.update_ppnorm(pAppp.z); + auto pAppp = blas::cDotProductNormA(p, Ap); + for (auto i = 0u; i < b.size(); i++) pAp[i] = pAppp[i].x; + ru.update_ppnorm(pAppp[0].z); // using 0th system for RU } else { - pAp = blas::reDotProduct(x_update_batch.get_current_field(), Ap); + pAp = blas::reDotProduct(p, Ap); } - x_update_batch.get_current_alpha() = r2 / pAp; + for (auto i = 0u; i < b.size(); i++) x_update_batch[i].get_current_alpha() = r2[i] / pAp[i]; // here we are deploying the alternative beta computation - double2 cg_norm = blas::axpyCGNorm(-x_update_batch.get_current_alpha(), Ap, rSloppy); - r2 = cg_norm.x; // (r_new, r_new) - sigma = cg_norm.y >= 0.0 ? cg_norm.y : r2; // use r2 if (r_k+1, r_k+1-r_k) breaks + auto cg_norm = blas::axpyCGNorm(-get_alpha(x_update_batch), Ap, r_sloppy); + for (auto i = 0u; i < b.size(); i++) { + r2[i] = cg_norm[i].x; // (r_new, r_new) + sigma[i] = cg_norm[i].y >= 0.0 ? 
cg_norm[i].y : r2[i]; // use r2 if (r_k+1, r_k+1-r_k) breaks + } } // reliable update conditions - ru.update_rNorm(sqrt(r2)); + ru.update_rNorm(sqrt(r2[0])); if (advanced_feature) { - ru.evaluate(r2_old); + ru.evaluate(r2_old[0]); // force a reliable update if we are within target tolerance (only if doing reliable updates) - if (convergenceL2(r2, 0.0, stop, 0.0) && param.delta >= param.tol) ru.set_updateX(); + if (convergenceL2(r2, stop) && param.delta >= param.tol) ru.set_updateX(); } if (!ru.trigger()) { - beta = sigma / r2_old; // use the alternative beta computation + for (auto i = 0u; i < beta.size(); i++) beta[i] = sigma[i] / r2_old[i]; // use the alternative beta computation if (advanced_feature && param.pipeline && !breakdown) { if (Np == 1) { - blas::tripleCGUpdate(x_update_batch.get_current_alpha(), beta, Ap, xSloppy, rSloppy, - x_update_batch.get_current_field()); + blas::tripleCGUpdate(get_alpha(x_update_batch), beta, Ap, x_sloppy, r_sloppy, p); } else { errorQuda("Not implemented pipelined CG with Np > 1"); } } else { if (Np == 1) { // with Np=1 we just run regular fusion between x and p updates - blas::axpyZpbx(x_update_batch.get_current_alpha(), x_update_batch.get_current_field(), xSloppy, rSloppy, - beta); + blas::axpyZpbx(get_alpha(x_update_batch), p, x_sloppy, r_sloppy, beta); } else { - if (x_update_batch.is_container_full()) { x_update_batch.accumulate_x(xSloppy); } + for (auto i = 0u; i < b.size(); i++) + if (x_update_batch[i].is_container_full()) { x_update_batch[i].accumulate_x(x_sloppy[i]); } // p[(k+1)%Np] = r + beta * p[k%Np] - blas::xpayz(rSloppy, beta, x_update_batch.get_current_field(), x_update_batch.get_next_field()); + blas::xpayz(r_sloppy, beta, p, p_next); } } // alternative reliable updates - if (advanced_feature) { ru.accumulate_norm(x_update_batch.get_current_alpha()); } + if (advanced_feature) { ru.accumulate_norm(x_update_batch[0].get_current_alpha()); } } else { - x_update_batch.accumulate_x(xSloppy); - 
x_update_batch.reset_next(); - - blas::copy(x, xSloppy); // nop when these pointers alias + for (auto i = 0u; i < b.size(); i++) { + x_update_batch[i].accumulate_x(x_sloppy[i]); + x_update_batch[i].reset_next(); + } + blas::copy(x, x_sloppy); // nop when these pointers alias blas::xpy(x, y); // swap these around? mat(r, y); // here we can use x as tmp r2 = blas::xmyNorm(b, r); - if (param.deflate && sqrt(r2) < ru.maxr_deflate * param.tol_restart) { + if (param.deflate && sqrt(r2[0]) < ru.maxr_deflate * param.tol_restart) { // Deflate and accumulate to solution vector eig_solve->deflate(y, r, evecs, evals, true); @@ -463,28 +510,32 @@ namespace quda { mat(r, y); r2 = blas::xmyNorm(b, r); - ru.update_maxr_deflate(r2); + ru.update_maxr_deflate(r2[0]); } - blas::copy(rSloppy, r); //nop when these pointers alias - blas::zero(xSloppy); + blas::copy(r_sloppy, r); // nop when these pointers alias + blas::zero(x_sloppy); - if (advanced_feature) { ru.update_norm(r2, y); } + if (advanced_feature) { ru.update_norm(r2[0], y[0]); } if (advanced_feature) { // needed as a "dummy parameter" to reliable_break. 
bool L2breakdown = false; - if (ru.reliable_break(r2, stop, L2breakdown, 0)) { break; } + if (ru.reliable_break(r2[0], stop[0], L2breakdown, 0)) { break; } } // explicitly restore the orthogonality of the gradient vector - Complex rp = blas::cDotProduct(rSloppy, x_update_batch.get_current_field()) / (r2); - blas::caxpy(-rp, rSloppy, x_update_batch.get_current_field()); + auto p = get_p(x_update_batch); + auto p_next = get_p(x_update_batch, true); + + auto rp = blas::cDotProduct(r_sloppy, p); + for (auto i = 0u; i < b.size(); i++) rp[i] /= r2[i]; + blas::caxpy(-rp, r_sloppy, p); - beta = r2 / r2_old; - blas::xpayz(rSloppy, beta, x_update_batch.get_current_field(), x_update_batch.get_next_field()); + for (auto i = 0u; i < beta.size(); i++) beta[i] = r2[i] / r2_old[i]; + blas::xpayz(r_sloppy, beta, p, p_next); - ru.reset(r2); + ru.reset(r2[0]); } breakdown = false; @@ -492,21 +543,23 @@ namespace quda { PrintStats("CG", k, r2, b2, 0.0); // check convergence - converged = convergenceL2(r2, 0.0, stop, 0.0); + converged = convergenceL2(r2, stop); // if we have converged and need to update any trailing solutions - if (converged && ru.steps_since_reliable > 0 && !x_update_batch.is_container_full()) { - x_update_batch.accumulate_x(xSloppy); - } + for (auto i = 0u; i < b.size(); i++) { + if (converged && ru.steps_since_reliable > 0 && !x_update_batch[i].is_container_full()) { + x_update_batch[i].accumulate_x(x_sloppy[i]); + } - if (ru.steps_since_reliable == 0) { - x_update_batch.reset(); - } else { - ++x_update_batch; + if (ru.steps_since_reliable == 0) { + x_update_batch[i].reset(); + } else { + ++x_update_batch[i]; + } } } - blas::copy(x, xSloppy); + blas::copy(x, x_sloppy); blas::xpy(y, x); if (!param.is_preconditioner) { @@ -523,8 +576,12 @@ namespace quda { if (advanced_feature && param.compute_true_res) { // compute the true residuals mat(r, x); - param.true_res = sqrt(blas::xmyNorm(b, r) / b2); - param.true_res_hq = sqrt(blas::HeavyQuarkResidualNorm(x, r).z); + auto 
true_r2 = blas::xmyNorm(b, r); + auto hq = blas::HeavyQuarkResidualNorm(x, r); + for (auto i = 0u; i < b.size(); i++) { + param.true_res[i] = sqrt(true_r2[i] / b2[i]); + param.true_res_hq[i] = sqrt(hq[i].z); + } } PrintSummary("CG", k, r2, b2, stop, 0.0); @@ -534,14 +591,14 @@ namespace quda { if (param.is_preconditioner) commGlobalReductionPop(); } - ColorSpinorField &CG::get_residual() + cvector_ref CG::get_residual() { if (!init) errorQuda("No residual vector present"); return r; } // Separate HQ residual codepath - void CG::hqsolve(ColorSpinorField &x, const ColorSpinorField &b) + void CG::hqsolve(cvector_ref &x, cvector_ref &b) { logQuda(QUDA_VERBOSE, "Performing a HQ CG solve\n"); @@ -557,7 +614,7 @@ namespace quda { getProfile().TPSTART(QUDA_PROFILE_INIT); - double b2 = blas::norm2(b); + vector b2 = blas::norm2(b); // Detect whether this is a pure double solve or not; informs the necessity of some stability checks bool is_pure_double = (param.precision == QUDA_DOUBLE_PRECISION && param.precision_sloppy == QUDA_DOUBLE_PRECISION); @@ -565,13 +622,19 @@ namespace quda { bool heavy_quark_restart = false; // Check to see that we're not trying to invert on a zero-field source - if (b2 == 0 && param.compute_null_vector == QUDA_COMPUTE_NULL_VECTOR_NO) { - getProfile().TPSTOP(QUDA_PROFILE_INIT); - printfQuda("Warning: inverting on zero-field source\n"); - x = b; - param.true_res = 0.0; - param.true_res_hq = 0.0; - return; + if (param.compute_null_vector == QUDA_COMPUTE_NULL_VECTOR_NO) { + bool zero_src = true; + for (auto i = 0u; i < b.size(); i++) { + if (b2[i] == 0) { + warningQuda("inverting on zero-field source"); + x[i] = b[i]; + param.true_res[i] = 0.0; + param.true_res_hq[i] = 0.0; + } else { + zero_src = false; + } + } + if (zero_src) { getProfile().TPSTOP(QUDA_PROFILE_INIT); return; } // stop the INIT timer started above before bailing out } create(x, b); @@ -585,36 +648,44 @@ namespace quda { const double hq_res_stall_check = is_pure_double ? 0.
: uhigh * uhigh * 1e-60; // compute initial residual - double r2 = 0.0; + vector r2(b.size()); if (param.use_init_guess == QUDA_USE_INIT_GUESS_YES) { // Compute r = b - A * x mat(r, x); r2 = blas::xmyNorm(b, r); - if (b2 == 0) b2 = r2; + for (auto i = 0u; i < b.size(); i++) + if (b2[i] == 0) b2[i] = r2[i]; // y contains the original guess. blas::copy(y, x); } else { - if (&r != &b) blas::copy(r, b); + blas::copy(r, b); r2 = b2; blas::zero(y); } blas::zero(x); - if (&x != &xSloppy) blas::zero(xSloppy); - blas::copy(rSloppy, r); - blas::copy(p, rSloppy); + if (param.use_sloppy_partial_accumulator) blas::zero(x_sloppy); + blas::copy(r_sloppy, r); + blas::copy(p, r_sloppy); - double r2_old = 0.0; + vector r2_old(b.size(), 0.0); getProfile().TPSTOP(QUDA_PROFILE_INIT); getProfile().TPSTART(QUDA_PROFILE_PREAMBLE); - double stop = stopping(param.tol, b2, param.residual_type); // stopping condition of solver + auto stop = stopping(param.tol, b2, param.residual_type); // stopping condition of solver + + auto get_hq_res = [](cvector_ref &x, cvector_ref &r) { + auto hq_nrm = blas::HeavyQuarkResidualNorm(x, r); + vector hq_res(hq_nrm.size()); + for (auto i = 0u; i < hq_nrm.size(); i++) hq_res[i] = sqrt(hq_nrm[i].z); + return hq_res; + }; // compute the initial heavy quark residual - double hq_res = sqrt(blas::HeavyQuarkResidualNorm(x, r).z); + vector hq_res = get_hq_res(x, r); - double alpha, beta, sigma, pAp; + vector alpha(b.size()), beta(b.size()), sigma(b.size()), pAp(b.size()); // Whether or not we also need to compute the L2 norm const bool L2_required = param.residual_type & (QUDA_L2_RELATIVE_RESIDUAL | QUDA_L2_ABSOLUTE_RESIDUAL); @@ -638,13 +709,13 @@ namespace quda { // Trackers for the L2 norm: // rNorm: current iterated |r| // r0Norm: computed |r| at the last reliable update - double rNorm = sqrt(r2); - double r0Norm = rNorm; + auto rNorm = sqrt(r2[0]); + auto r0Norm = rNorm; // If the computed |r| goes above r0Norm between reliable updates, // update this ceiling.
This goes into "R" type reliable updates. - double maxrx = L2breakdown ? hq_res : rNorm; - double maxrr = L2breakdown ? hq_res : rNorm; + double maxrx = L2breakdown ? hq_res[0] : rNorm; + double maxrr = L2breakdown ? hq_res[0] : rNorm; // Triggers for explicitly counting residual updates and checking for L2breakdown. // * updateX broadly maps to if the iterated residual has dropped by a factor of delta @@ -665,7 +736,7 @@ namespace quda { // Trackers for the HQ residual // hq0Res: computed HQ residual at the last reliable updated - double hq0Res = hq_res; + auto hq0Res = hq_res; // Counter for the number of times in a row the computed heavy quark residual has // jumped above the previously computed heavy quark residual. @@ -688,20 +759,22 @@ namespace quda { pAp = blas::reDotProduct(p, Ap); - alpha = r2 / pAp; + for (auto i = 0u; i < alpha.size(); i++) alpha[i] = r2[i] / pAp[i]; // here we are deploying the alternative beta computation - double2 cg_norm = blas::axpyCGNorm(-alpha, Ap, rSloppy); - r2 = cg_norm.x; // (r_new, r_new) - sigma = cg_norm.y >= 0.0 ? cg_norm.y : r2; // use r2 if (r_k+1, r_k+1-r_k) breaks - rNorm = sqrt(r2); + auto cg_norm = blas::axpyCGNorm(-alpha, Ap, r_sloppy); + for (auto i = 0u; i < cg_norm.size(); i++) { + r2[i] = cg_norm[i].x; // (r_new, r_new) + sigma[i] = cg_norm[i].y >= 0.0 ? cg_norm[i].y : r2[i]; // use r2 if (r_k+1, r_k+1-r_k) breaks + } + rNorm = sqrt(r2[0]); // If the iterated norm has dropped by more than a factor of delta, trigger // an update. The baseline we check against differs depending on if // we're still checking the L2 norm, or if that has converged/broken down and we're // now looking at the HQ residual. 
- if (!L2breakdown && (L2_required || convergenceL2(r2, hq_res, stop, param.tol_hq))) { + if (!L2breakdown && (L2_required || convergenceL2(r2, stop))) { // L2 based reliable update // If the iterated residual norm has gone above the most recent "baseline" norm, @@ -717,42 +790,50 @@ namespace quda { updateR = ((rNorm < param.delta * maxrr && r0Norm <= maxrr) || updateX); } else { // hqresidual based reliable update - if (hq_res > maxrx) maxrx = hq_res; - if (hq_res > maxrr) maxrr = hq_res; + if (hq_res[0] > maxrx) maxrx = hq_res[0]; + if (hq_res[0] > maxrr) maxrr = hq_res[0]; // I'm making the decision to use `param.delta` for the hq_res check because // in some regards it's an L2-esque norm... // Has the iterated heavy quark residual dropped by a factor of delta^2 from the last // computed norm? - updateX = (hq_res < param.delta * param.delta * hq0Res && r0Norm <= maxrx); + updateX = (hq_res[0] < param.delta * param.delta * hq0Res[0] && r0Norm <= maxrx); // Has the iterated heavy quark residual dropped by a factor of delta relative // to the largest the iterated norm has been since the last update?
- updateR = ((hq_res < param.delta * param.delta * maxrr && hq0Res <= maxrr) || updateX); + updateR = ((hq_res[0] < param.delta * param.delta * maxrr && hq0Res[0] <= maxrr) || updateX); } // force a reliable update if we are within target tolerance (only if doing reliable updates) if (convergence(r2, hq_res, stop, param.tol_hq) && param.delta >= param.tol) updateX = true; // force a reliable update based on the HQ residual if L2 breakdown has already happened - if (L2breakdown && (convergenceHQ(r2, hq_res, stop, param.tol_hq) || (r2 / b2) < hq_res_stall_check) + if (L2breakdown && (convergenceHQ(hq_res, param.tol_hq) || (r2[0] / b2[0]) < hq_res_stall_check) && param.delta >= param.tol) updateX = true; if (!(updateR || updateX)) { // No reliable update needed - beta = sigma / r2_old; // use the alternative beta computation + for (auto i = 0u; i < beta.size(); i++) beta[i] = sigma[i] / r2_old[i]; // use the alternative beta computation - blas::axpyZpbx(alpha, p, xSloppy, rSloppy, beta); + blas::axpyZpbx(alpha, p, x_sloppy, r_sloppy, beta); + + auto get_hq_res2 = [](cvector_ref &x, cvector_ref &y, + cvector_ref &r) { + auto hq_nrm = blas::xpyHeavyQuarkResidualNorm(x, y, r); + vector hq_res(hq_nrm.size()); + for (auto i = 0u; i < hq_nrm.size(); i++) hq_res[i] = sqrt(hq_nrm[i].z); + return hq_res; + }; if (k % param.heavy_quark_check == 0) { - if (xSloppy.Precision() != rSloppy.Precision()) { - blas::copy(r, rSloppy); - hq_res = sqrt(blas::xpyHeavyQuarkResidualNorm(xSloppy, y, r).z); + if (param.precision != param.precision_sloppy) { + blas::copy(r, r_sloppy); + hq_res = get_hq_res2(x_sloppy, y, r); } else { - hq_res = sqrt(blas::xpyHeavyQuarkResidualNorm(xSloppy, y, rSloppy).z); + hq_res = get_hq_res2(x_sloppy, y, r_sloppy); } } @@ -762,23 +843,23 @@ namespace quda { // We're performing a reliable update // Accumulate p into x, accumulate x into the total solution y, explicitly recompute the residual vector - blas::axpy(alpha, p, xSloppy); - blas::copy(x, xSloppy); // no
op when these pointers alias + blas::axpy(alpha, p, x_sloppy); + blas::copy(x, x_sloppy); // no op when these pointers alias blas::xpy(x, y); mat(r, y); // Recompute the exact residual and heavy quark residual r2 = blas::xmyNorm(b, r); - rNorm = sqrt(r2); - hq_res = sqrt(blas::HeavyQuarkResidualNorm(y, r).z); + rNorm = sqrt(r2[0]); + hq_res = get_hq_res(y, r); // Copy and update fields - blas::copy(rSloppy, r); // no op when these pointers alias - blas::zero(xSloppy); + blas::copy(r_sloppy, r); // no op when these pointers alias + blas::zero(x_sloppy); // Check and see if we're "done" with the L2 norm. This could be because // we were already done with it, we never needed it, or the L2 norm has finally converged. - if (!L2breakdown && convergenceL2(r2, hq_res, stop, param.tol_hq)) L2breakdown = true; + if (!L2breakdown && convergenceL2(r2, stop)) L2breakdown = true; // Depending on if we're still grinding on the L2 norm or if we've moved along to just // the HQ norm, we reset the baselines for reliable updates that get used on the @@ -792,8 +873,8 @@ namespace quda { } else { // If we've made it to the HQ norm, the new baseline is the freshly recomputed // heavy quark residual - maxrr = hq_res; - maxrx = hq_res; + maxrr = hq_res[0]; + maxrx = hq_res[0]; // Once we're dealing with the heavy quark residual, we perform a *hard* CG // restart at every reliable update via setting the search vector `p` to the current @@ -828,7 +909,7 @@ namespace quda { // ...tell the world about it too. 
warningQuda( "new reliable residual norm %e is greater than previous reliable residual norm %e (total #inc %i)", - sqrt(r2), r0Norm, resIncreaseTotal); + sqrt(r2[0]), r0Norm, resIncreaseTotal); // If the norm is ridiculously small in magnitude, we've exceeded the maximums on various // ways we keep track of residual increases, or the L2 norm converged, we say "we're good here" @@ -840,8 +921,8 @@ namespace quda { // We also have to do a logic correction, switching the reliable update baselines we set above // from the L2 norm over to the HQ residual. - maxrr = hq_res; - maxrx = hq_res; + maxrr = hq_res[0]; + maxrx = hq_res[0]; } } else { // This variable counts the number of times in a row the computed residual has gone up, @@ -856,8 +937,8 @@ namespace quda { hqresIncrease++; // Tell the world about it - warningQuda("CG: new reliable HQ residual norm %e is greater than previous reliable residual norm %e", hq_res, - hq0Res); + warningQuda("CG: new reliable HQ residual norm %e is greater than previous reliable residual norm %e", + hq_res[0], hq0Res[0]); // And if it's increased too many times in a row, flunk out of the solve. if (hqresIncrease > param.max_hq_res_increase) { @@ -876,17 +957,18 @@ namespace quda { if (heavy_quark_restart) { // If we're in the HQ residual part of the solve, we just do a hard CG restart. logQuda(QUDA_DEBUG_VERBOSE, "HQ restart == hard CG restart\n"); - blas::copy(p, rSloppy); + blas::copy(p, r_sloppy); heavy_quark_restart = false; } else { // If we're still in the L2 norm part of the solve, we explicitly restore // the orthogonality of the gradient vector, recompute beta, update `p`, and carry on with our lives. 
logQuda(QUDA_DEBUG_VERBOSE, "Regular restart == explicit gradient vector re-orthogonalization\n"); - Complex rp = blas::cDotProduct(rSloppy, p) / (r2); - blas::caxpy(-rp, rSloppy, p); + auto rp = blas::cDotProduct(r_sloppy, p); + for (auto i = 0u; i < b.size(); i++) rp[i] /= r2[i]; + blas::caxpy(-rp, r_sloppy, p); - beta = r2 / r2_old; - blas::xpayz(rSloppy, beta, p, p); + for (auto i = 0u; i < b.size(); i++) beta[i] = r2[i] / r2_old[i]; + blas::xpayz(r_sloppy, beta, p, p); } // Last, we increment the reliable update counter, reset the number of steps since the last reliable update, @@ -894,7 +976,7 @@ namespace quda { // reliable update. rUpdate++; steps_since_reliable = 0; - r0Norm = sqrt(r2); + r0Norm = sqrt(r2[0]); hq0Res = hq_res; } @@ -908,13 +990,13 @@ namespace quda { // check for recent enough reliable updates of the HQ residual if we use it // L2 is converged or precision maxed out for L2 - bool L2done = L2breakdown || convergenceL2(r2, hq_res, stop, param.tol_hq); + bool L2done = L2breakdown || convergenceL2(r2, stop); // HQ is converged and if we do reliable update the HQ residual has been calculated using a reliable update - bool HQdone = (steps_since_reliable == 0 && param.delta > 0) && convergenceHQ(r2, hq_res, stop, param.tol_hq); + bool HQdone = (steps_since_reliable == 0 && param.delta > 0) && convergenceHQ(hq_res, param.tol_hq); converged = L2done && HQdone; } - blas::copy(x, xSloppy); + blas::copy(x, x_sloppy); blas::xpy(y, x); getProfile().TPSTOP(QUDA_PROFILE_COMPUTE); @@ -929,8 +1011,12 @@ namespace quda { if (param.compute_true_res) { // compute the true residuals mat(r, x); - param.true_res = sqrt(blas::xmyNorm(b, r) / b2); - param.true_res_hq = sqrt(blas::HeavyQuarkResidualNorm(x, r).z); + auto true_r2 = blas::xmyNorm(b, r); + auto hq = blas::HeavyQuarkResidualNorm(x, r); + for (auto i = 0u; i < b.size(); i++) { + param.true_res[i] = sqrt(true_r2[i] / b2[i]); + param.true_res_hq[i] = sqrt(hq[i].z); + } } PrintSummary("CG", k, r2, b2, stop,
param.tol_hq); diff --git a/lib/inv_gcr_quda.cpp b/lib/inv_gcr_quda.cpp index 6066f1d14c..aa9920a123 100644 --- a/lib/inv_gcr_quda.cpp +++ b/lib/inv_gcr_quda.cpp @@ -175,7 +175,7 @@ namespace quda { } } - ColorSpinorField &GCR::get_residual() + cvector_ref GCR::get_residual() { if (!init) errorQuda("No residual vector present"); if (param.compute_true_res) diff --git a/lib/inv_mr_quda.cpp b/lib/inv_mr_quda.cpp index b50c349c8f..aee322bd21 100644 --- a/lib/inv_mr_quda.cpp +++ b/lib/inv_mr_quda.cpp @@ -20,32 +20,31 @@ namespace quda } } - void MR::create(ColorSpinorField &x, const ColorSpinorField &b) + void MR::create(cvector_ref &x, cvector_ref &b) { Solver::create(x, b); if (!init) { - ColorSpinorParam csParam(b); - csParam.create = QUDA_NULL_FIELD_CREATE; - - r = ColorSpinorField(csParam); + resize(r, b.size(), QUDA_NULL_FIELD_CREATE, b[0]); // now allocate sloppy fields + ColorSpinorParam csParam(b[0]); + csParam.create = QUDA_NULL_FIELD_CREATE; csParam.setPrecision(param.precision_sloppy); - Ar = ColorSpinorField(csParam); - x_sloppy = ColorSpinorField(csParam); + resize(Ar, b.size(), csParam); + resize(x_sloppy, b.size(), csParam); - bool mixed = param.precision != param.precision_sloppy; - - if (!mixed) csParam.create = QUDA_REFERENCE_FIELD_CREATE; - csParam.v = r.data(); - r_sloppy = ColorSpinorField(csParam); + if (param.precision != param.precision_sloppy) { // mixed precision + resize(r_sloppy, b.size(), csParam); + } else { + create_alias(r_sloppy, r); + } init = true; } // init } - ColorSpinorField &MR::get_residual() + cvector_ref MR::get_residual() { if (!init) errorQuda("No residual vector present"); if (!param.return_residual) errorQuda("SolverParam::return_residual not enabled"); @@ -53,7 +52,7 @@ namespace quda return r; } - void MR::operator()(ColorSpinorField &x, const ColorSpinorField &b) + void MR::operator()(cvector_ref &x, cvector_ref &b) { if (param.maxiter == 0 || param.Nsteps == 0) { if (param.use_init_guess == QUDA_USE_INIT_GUESS_NO) 
blas::zero(x); @@ -64,11 +63,13 @@ namespace quda if (!param.is_preconditioner) getProfile().TPSTART(QUDA_PROFILE_COMPUTE); - double b2 = blas::norm2(b); // Save norm of b - double r2 = 0.0; // if zero source then we will exit immediately doing no work + vector b2 = blas::norm2(b); // Save norm of b + vector r2; // if zero source then we will exit immediately doing no work + if (param.use_init_guess == QUDA_USE_INIT_GUESS_YES) { mat(r, x); r2 = blas::xmyNorm(b, r); // r = b - Ax0 + for (auto i = 0u; i < b.size(); i++) if (b2[i] == 0) b2[i] = r2[i]; } else { r2 = b2; blas::copy(r, b); @@ -76,18 +77,20 @@ namespace quda } blas::copy(r_sloppy, r); - // if invalid residual then convergence is set by iteration count only - double stop = param.residual_type == QUDA_INVALID_RESIDUAL ? 0.0 : b2 * param.tol * param.tol; + auto stop = stopping(param.tol, b2, param.residual_type); int iter = 0; int step = 0; bool converged = false; - PrintStats("MR", iter, r2, b2, 0.0); + PrintStats("MR", iter, r2, b2); while (!converged) { int k = 0; - double scale = 1.0; + vector scale(b.size(), 1.0); + vector scale_inv(b.size(), 1.0); + vector delta2(b.size(), param.delta * param.delta); + if ((node_parity + step) % 2 == 0 && param.schwarz_type == QUDA_MULTIPLICATIVE_SCHWARZ) { // for multiplicative Schwarz we alternate updates depending on node parity } else { @@ -95,24 +98,27 @@ namespace quda commGlobalReductionPush(param.global_reduction); // use local reductions for DD solver blas::zero(x_sloppy); // can get rid of this for a special first update kernel - double c2 = param.global_reduction == QUDA_BOOLEAN_TRUE ? r2 : blas::norm2(r); // c2 holds the initial r2 - scale = c2 > 0.0 ? sqrt(c2) : 1.0; - - // domain-wise normalization of the initial residual to prevent underflow - if (c2 > 0.0) { - blas::ax(1 / scale, r_sloppy); // can merge this with the prior copy - r2 = 1.0; // by definition by this is now true + auto c2 = param.global_reduction == QUDA_BOOLEAN_TRUE ? 
r2 : blas::norm2(r); // c2 holds the initial r2 + for (auto i = 0u; i < b.size(); i++) { + scale[i] = c2[i] > 0.0 ? sqrt(c2[i]) : 1.0; + scale_inv[i] = 1.0 / scale[i]; + // domain-wise normalization of the initial residual to prevent underflow + if (c2[i] > 0.0) r2[i] = 1.0; // by definition by this is now true } + blas::ax(scale_inv, r_sloppy); // can merge this with the prior copy - while (k < param.maxiter && r2 > param.delta * param.delta) { + while (k < param.maxiter && r2 > delta2) { matSloppy(Ar, r_sloppy); if (param.global_reduction) { - double4 Ar4 = blas::cDotProductNormAB(Ar, r_sloppy); - Complex alpha = Complex(Ar4.x, Ar4.y) / Ar4.z; - r2 = Ar4.w; - PrintStats("MR (inner)", iter, r2, b2, 0.0); + auto Ar4 = blas::cDotProductNormAB(Ar, r_sloppy); + vector alpha(b.size()); + for (auto i = 0u; i < b.size(); i++) { + alpha[i] = Complex(Ar4[i].x, Ar4[i].y) / Ar4[i].z; + r2[i] = Ar4[i].w; + } + PrintStats("MR (inner)", iter, r2, b2); // x += omega*alpha*r, r -= omega*alpha*Ar, r2 = blas::norm2(r) blas::caxpyXmaz(param.omega * alpha, r_sloppy, x_sloppy, Ar); @@ -140,7 +146,7 @@ namespace quda if (compute_true_res) { mat(r, x); r2 = blas::xmyNorm(b, r); - param.true_res = sqrt(r2 / b2); + for (auto i = 0u; i < b2.size(); i++) param.true_res[i] = sqrt(r2[i] / b2[i]); converged = (step < param.Nsteps && r2 > stop) ? 
false : true; if (!converged) blas::copy(r_sloppy, r); PrintStats("MR (restart)", iter, r2, b2, 0.0); diff --git a/lib/inv_sd_quda.cpp b/lib/inv_sd_quda.cpp index 484b06c7f2..5a1906ba01 100644 --- a/lib/inv_sd_quda.cpp +++ b/lib/inv_sd_quda.cpp @@ -11,55 +11,73 @@ namespace quda { SD::SD(const DiracMatrix &mat, SolverParam &param) : Solver(mat, mat, mat, mat, param) { } - void SD::create(ColorSpinorField &x, const ColorSpinorField &b) + void SD::create(cvector_ref &x, cvector_ref &b) { Solver::create(x, b); if (!init) { - ColorSpinorParam csParam(b); - csParam.create = QUDA_NULL_FIELD_CREATE; - r = ColorSpinorField(csParam); - Ar = ColorSpinorField(csParam); + resize(r, b.size(), QUDA_NULL_FIELD_CREATE, b[0]); + resize(Ar, b.size(), QUDA_NULL_FIELD_CREATE, b[0]); init = true; } } - ColorSpinorField &SD::get_residual() + cvector_ref SD::get_residual() { if (!init) errorQuda("No residual vector present"); return r; } - void SD::operator()(ColorSpinorField &x, const ColorSpinorField &b) + void SD::operator()(cvector_ref &x, cvector_ref &b) { commGlobalReductionPush(param.global_reduction); create(x, b); - double b2 = blas::norm2(b); - double r2; + vector b2 = blas::norm2(b); + vector r2; + + // Check to see that we're not trying to invert on a zero-field source + if (param.compute_null_vector == QUDA_COMPUTE_NULL_VECTOR_NO) { + bool zero_src = true; + for (auto i = 0u; i < b.size(); i++) { + if (b2[i] == 0) { + warningQuda("inverting on zero-field source"); + x[i] = b[i]; + param.true_res[i] = 0.0; + param.true_res_hq[i] = 0.0; + } else { + zero_src = false; + } + } + if (zero_src) { commGlobalReductionPop(); return; } // balance the commGlobalReductionPush at function entry + } + if (param.use_init_guess == QUDA_USE_INIT_GUESS_YES) { // Compute the true residual mat(r, x); r2 = blas::xmyNorm(b, r); + for (auto i = 0u; i < b.size(); i++) if (b2[i] == 0) b2[i] = r2[i]; } else { blas::zero(x); blas::copy(r, b); r2 = b2; } - // if invalid residual then convergence is set by iteration count only - double stop = param.residual_type == QUDA_INVALID_RESIDUAL ?
0.0 : b2 * param.tol * param.tol; + auto stop = stopping(param.tol, b2, param.residual_type); int res_increase = 0; int k = 0; while (k < param.maxiter) { mat(Ar, r); - double3 rAr = blas::cDotProductNormA(r, Ar); - auto alpha = rAr.z / rAr.x; - r2 = rAr.z; // this is r2 from the prior iteration + auto rAr = blas::cDotProductNormA(r, Ar); + vector alpha(b.size()); + for (auto i = 0u; i < b.size(); i++) { + alpha[i] = rAr[i].z / rAr[i].x; + r2[i] = rAr[i].z; // this is r2 from the prior iteration + } - PrintStats("SD", k, r2, b2, 0.0); + PrintStats("SD", k, r2, b2); if (r2 < stop) { mat(r, x); @@ -80,11 +98,11 @@ namespace quda { if (param.compute_true_res) { // Compute the true residual mat(r, x); - double true_r2 = blas::xmyNorm(b, r); - PrintSummary("SD", k, true_r2, b2, 0.0, 0.0); - param.true_res = sqrt(true_r2 / b2); + auto true_r2 = blas::xmyNorm(b, r); + PrintSummary("SD", k, true_r2, b2, stop); + for (auto i = 0u; i < b2.size(); i++) param.true_res[i] = sqrt(true_r2[i] / b2[i]); } else { - PrintSummary("SD", k, r2, b2, 0.0, 0.0); + PrintSummary("SD", k, r2, b2, stop); } commGlobalReductionPop(); diff --git a/lib/solver.cpp b/lib/solver.cpp index 357e30f067..ee7278e041 100644 --- a/lib/solver.cpp +++ b/lib/solver.cpp @@ -34,10 +34,13 @@ namespace quda { } } - void Solver::create(ColorSpinorField &x, const ColorSpinorField &b) + void Solver::create(cvector_ref &x, cvector_ref &b) { if (checkPrecision(x, b) != param.precision) errorQuda("Precision mismatch %d %d", checkPrecision(x, b), param.precision); + + param.true_res.resize(b.size()); + param.true_res_hq.resize(b.size()); } // solver factory @@ -368,108 +371,125 @@ namespace quda { { for (int i = 0; i < param.num_src; i++) { (*this)(out.Component(i), in.Component(i)); - param.true_res_offset[i] = param.true_res; - param.true_res_hq_offset[i] = param.true_res_hq; + param.true_res_offset[i] = static_cast(param.true_res); + param.true_res_hq_offset[i] = static_cast(param.true_res_hq); } } - double 
Solver::stopping(double tol, double b2, QudaResidualType residual_type) + vector Solver::stopping(double tol, cvector &b2, QudaResidualType residual_type) { - double stop=0.0; + vector stop(b2.size(), 0.0); if ( (residual_type & QUDA_L2_ABSOLUTE_RESIDUAL) && (residual_type & QUDA_L2_RELATIVE_RESIDUAL) ) { - // use the most stringent stopping condition - double lowest = (b2 < 1.0) ? b2 : 1.0; - stop = lowest*tol*tol; + for (auto i = 0u; i < b2.size(); i++) { + // use the most stringent stopping condition + double lowest = (b2[i] < 1.0) ? b2[i] : 1.0; + stop[i] = lowest * tol * tol; + } } else if (residual_type & QUDA_L2_ABSOLUTE_RESIDUAL) { - stop = tol*tol; + for (auto i = 0u; i < b2.size(); i++) stop[i] = tol * tol; + } else if (residual_type & QUDA_L2_RELATIVE_RESIDUAL) { + for (auto i = 0u; i < b2.size(); i++) stop[i] = b2[i] * tol * tol; } else { - stop = b2*tol*tol; + // if invalid residual then convergence is set by iteration count only + for (auto i = 0u; i < b2.size(); i++) stop[i] = 0.0; } return stop; } - bool Solver::convergence(double r2, double hq2, double r2_tol, double hq_tol) { - - // check the heavy quark residual norm if necessary - if (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) { - if (std::isnan(hq2) || std::isinf(hq2)) - errorQuda("Solver appears to have diverged with heavy quark residual %9.6e", hq2); - - if (hq2 > hq_tol) return false; - } - - // check the L2 relative residual norm if necessary - if ((param.residual_type & QUDA_L2_RELATIVE_RESIDUAL) || (param.residual_type & QUDA_L2_ABSOLUTE_RESIDUAL)) { - if (std::isnan(r2) || std::isinf(r2)) errorQuda("Solver appears to have diverged with residual %9.6e", r2); + bool Solver::convergence(cvector &r2, cvector &hq2, cvector &r2_tol, cvector &hq_tol) + { + for (auto i = 0u; i < r2.size(); i++) { + // check the heavy quark residual norm if necessary + if (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) { + if (std::isnan(hq2[i]) || std::isinf(hq2[i])) + errorQuda("Solver appears to have 
diverged with heavy quark residual %9.6e", hq2[i]); + if (hq2[i] > hq_tol[i]) return false; + } - if (r2 > r2_tol) return false; + // check the L2 relative residual norm if necessary + if ((param.residual_type & QUDA_L2_RELATIVE_RESIDUAL) || (param.residual_type & QUDA_L2_ABSOLUTE_RESIDUAL)) { + if (std::isnan(r2[i]) || std::isinf(r2[i])) errorQuda("Solver appears to have diverged with residual %9.6e", r2[i]); + if (r2[i] > r2_tol[i]) return false; + } } - return true; } - bool Solver::convergenceHQ(double, double hq2, double, double hq_tol) + bool Solver::convergenceHQ(cvector &hq2, cvector &hq_tol) { - // check the heavy quark residual norm if necessary - if (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) { - if (std::isnan(hq2) || std::isinf(hq2)) - errorQuda("Solver appears to have diverged with heavy quark residual %9.6e", hq2); - - if (hq2 > hq_tol) return false; + for (auto i = 0u; i < hq2.size(); i++) { + // check the heavy quark residual norm if necessary + if (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) { + if (std::isnan(hq2[i]) || std::isinf(hq2[i])) + errorQuda("Solver appears to have diverged with heavy quark residual %9.6e", hq2[i]); + if (hq2[i] > hq_tol[i]) return false; + } } - return true; } - bool Solver::convergenceL2(double r2, double, double r2_tol, double) + bool Solver::convergenceL2(cvector &r2, cvector &r2_tol) { - // check the L2 relative residual norm if necessary - if ((param.residual_type & QUDA_L2_RELATIVE_RESIDUAL) || (param.residual_type & QUDA_L2_ABSOLUTE_RESIDUAL)) { - if (std::isnan(r2) || std::isinf(r2)) errorQuda("Solver appears to have diverged with residual %9.6e", r2); - - if (r2 > r2_tol) return false; + for (auto i = 0u; i < r2.size(); i++) { + // check the L2 relative residual norm if necessary + if ((param.residual_type & QUDA_L2_RELATIVE_RESIDUAL) || (param.residual_type & QUDA_L2_ABSOLUTE_RESIDUAL)) { + if (std::isnan(r2[i]) || std::isinf(r2[i])) errorQuda("Solver appears to have diverged with residual %9.6e", 
r2[i]); + if (r2[i] > r2_tol[i]) return false; + } } - return true; } - void Solver::PrintStats(const char* name, int k, double r2, double b2, double hq2) { - if (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) { - logQuda(QUDA_VERBOSE, "%s: %5d iterations, = %9.6e, |r|/|b| = %9.6e, heavy-quark residual = %9.6e\n", name, - k, r2, sqrt(r2 / b2), hq2); - } else { - logQuda(QUDA_VERBOSE, "%s: %5d iterations, = %9.6e, |r|/|b| = %9.6e\n", name, k, r2, sqrt(r2 / b2)); - } - - if (std::isnan(r2) || std::isinf(r2)) errorQuda("Solver appears to have diverged"); + std::string set_rhs_str(unsigned int i, size_t n) + { + std::string rhs_str; + if (n > 1) rhs_str += "n = " + std::to_string(i) + std::string(", "); + return rhs_str; } - void Solver::PrintSummary(const char *name, int k, double r2, double b2, - double r2_tol, double hq_tol) { - if (param.compute_true_res) { + void Solver::PrintStats(const char* name, int k, cvector &r2, cvector &b2, cvector &hq2) { + for (auto i = 0u; i < r2.size(); i++) { + auto rhs_str = set_rhs_str(i, r2.size()); if (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) { - logQuda(QUDA_SUMMARIZE, - "%s: Convergence at %d iterations, L2 relative residual: iterated = %9.6e, true = %9.6e " - "(requested = %9.6e), heavy-quark residual = %9.6e (requested = %9.6e)\n", - name, k, sqrt(r2 / b2), param.true_res, sqrt(r2_tol / b2), param.true_res_hq, hq_tol); + logQuda(QUDA_VERBOSE, "%s: %5d iterations, %s = %9.6e, |r|/|b| = %9.6e, heavy-quark residual = %9.6e\n", name, + k, rhs_str.c_str(), r2[i], sqrt(r2[i] / b2[i]), hq2[i]); } else { - logQuda(QUDA_SUMMARIZE, - "%s: Convergence at %d iterations, L2 relative residual: iterated = %9.6e, true = %9.6e " - "(requested = %9.6e)\n", - name, k, sqrt(r2 / b2), param.true_res, sqrt(r2_tol / b2)); + logQuda(QUDA_VERBOSE, "%s: %5d iterations, %s = %9.6e, |r|/|b| = %9.6e\n", name, k, rhs_str.c_str(), r2[i], sqrt(r2[i] / b2[i])); } - } else { - if (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) { - 
logQuda(QUDA_SUMMARIZE, - "%s: Convergence at %d iterations, L2 relative residual: iterated = %9.6e " - "(requested = %9.6e), heavy-quark residual = %9.6e (requested = %9.6e)\n", - name, k, sqrt(r2 / b2), sqrt(r2_tol / b2), param.true_res_hq, hq_tol); + + if (std::isnan(r2[i]) || std::isinf(r2[i])) errorQuda("Solver appears to have diverged for n = %d", i); + } + } + + void Solver::PrintSummary(const char *name, int k, cvector &r2, cvector &b2, + cvector &r2_tol, cvector &hq_tol) { + for (auto i = 0u; i < r2.size(); i++) { + auto rhs_str = set_rhs_str(i, r2.size()); + if (param.compute_true_res) { + if (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) { + logQuda(QUDA_SUMMARIZE, + "%s: Convergence at %d iterations, %sL2 relative residual: iterated = %9.6e, true = %9.6e " + "(requested = %9.6e), heavy-quark residual = %9.6e (requested = %9.6e)\n", + name, k, rhs_str.c_str(), sqrt(r2[i] / b2[i]), param.true_res[i], sqrt(r2_tol[i] / b2[i]), param.true_res_hq[i], hq_tol[i]); + } else { + logQuda(QUDA_SUMMARIZE, + "%s: Convergence at %d iterations, %sL2 relative residual: iterated = %9.6e, true = %9.6e " + "(requested = %9.6e)\n", + name, k, rhs_str.c_str(), sqrt(r2[i] / b2[i]), param.true_res[i], sqrt(r2_tol[i] / b2[i])); + } } else { - logQuda(QUDA_SUMMARIZE, - "%s: Convergence at %d iterations, L2 relative residual: iterated = %9.6e (requested = %9.6e)\n", name, - k, sqrt(r2 / b2), sqrt(r2_tol / b2)); + if (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) { + logQuda(QUDA_SUMMARIZE, + "%s: Convergence at %d iterations, %sL2 relative residual: iterated = %9.6e " + "(requested = %9.6e), heavy-quark residual = %9.6e (requested = %9.6e)\n", + name, k, rhs_str.c_str(), sqrt(r2[i] / b2[i]), sqrt(r2_tol[i] / b2[i]), param.true_res_hq[i], hq_tol[i]); + } else { + logQuda(QUDA_SUMMARIZE, + "%s: Convergence at %d iterations, %sL2 relative residual: iterated = %9.6e (requested = %9.6e)\n", name, + k, rhs_str.c_str(), sqrt(r2[i] / b2[i]), sqrt(r2_tol[i] / b2[i])); + } } } } diff 
--git a/tests/dslash_test_utils.h b/tests/dslash_test_utils.h index 914edb247d..837b0c91d3 100644 --- a/tests/dslash_test_utils.h +++ b/tests/dslash_test_utils.h @@ -319,8 +319,10 @@ struct DslashTestWrapper { dirac = Dirac::create(diracParam); } else { - double cpu_norm = blas::norm2(spinor); - printfQuda("Source: CPU = %e\n", cpu_norm); + for (int i = 0; i < Nsrc; i++) { + double cpu_norm = blas::norm2(spinor[i]); + printfQuda("Source %d: CPU = %e\n", i, cpu_norm); + } } } From 224bdb2bcad336d108c3568c5bb198c27969955e Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Mon, 1 Jul 2024 10:48:28 -0700 Subject: [PATCH 016/103] Accelerate MG::verify by using batch blas where applicable --- lib/multigrid.cpp | 38 ++++++++++++++++++++------------------ 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/lib/multigrid.cpp b/lib/multigrid.cpp index 5d7b9db979..b07a007238 100644 --- a/lib/multigrid.cpp +++ b/lib/multigrid.cpp @@ -789,6 +789,7 @@ namespace quda auto &tmp2 = fine_tmp[1]; auto &tmp_coarse = coarse_tmp[0]; + auto B_norm = norm2(param.B); // No need to check (projector) v_k for staggered case if (param.transfer_type == QUDA_TRANSFER_AGGREGATE) { @@ -799,17 +800,19 @@ namespace quda transfer->R(coarse_tmp, param.B); transfer->P(fine_tmp, coarse_tmp); + auto max_deviation = blas::max_deviation(param.B, fine_tmp); + auto deviation = xmyNorm(param.B, fine_tmp); + auto coarse_norm = norm2(coarse_tmp); + auto fine_norm = norm2(coarse_tmp); for (auto i = 0; i < param.Nvec; i++) { - auto max_deviation = blas::max_deviation(param.B[i], fine_tmp[i]); - auto l2_deviation = sqrt(xmyNorm(param.B[i], fine_tmp[i]) / norm2(param.B[i])); - - logQuda(QUDA_VERBOSE, - "Vector %d: L2 norms v_k = %e P^\\dagger v_k = %e (1 - P P^\\dagger) v_k = %e; Deviations: L2 relative = %e, max = %e\n", - i, norm2(param.B[i]), norm2(coarse_tmp[i]), norm2(fine_tmp[i]), l2_deviation, max_deviation[0]); + auto l2_deviation = sqrt(deviation[i]) / B_norm[i]; + logQuda( + QUDA_VERBOSE, 
"Vector %d: L2 norms v_k = %e P^\\dagger v_k = %e (1 - P P^\\dagger) v_k = %e; Deviations: L2 relative = %e, max = %e\n", + i, B_norm[i], coarse_norm[i], fine_norm[i], l2_deviation, max_deviation[i][0]); if (check_deviation(l2_deviation, tol)) errorQuda("k=%d orthonormality failed: L2 relative deviation %e > %e", i, l2_deviation, tol); - if (check_deviation(max_deviation[0], tol)) - errorQuda("k=%d orthonormality failed: max deviation %e > %e", i, max_deviation[0], tol); + if (check_deviation(max_deviation[i][0], tol)) + errorQuda("k=%d orthonormality failed: max deviation %e > %e", i, max_deviation[i][0], tol); } for (auto &f : fine_tmp) f.GammaBasis(r.GammaBasis()); // restore basis @@ -829,8 +832,8 @@ namespace quda transfer->P(tmp2, x_coarse); (*param.matResidual)(tmp1, tmp2); tmp2 = param.B[i]; - logQuda(QUDA_SUMMARIZE, "Vector %d: norms %e %e\n", i, norm2(param.B[i]), norm2(tmp1)); - logQuda(QUDA_SUMMARIZE, "relative residual = %e\n", sqrt(xmyNorm(tmp2, tmp1) / norm2(param.B[i]))); + logQuda(QUDA_SUMMARIZE, "Vector %d: norms %e %e\n", i, B_norm[i], norm2(tmp1)); + logQuda(QUDA_SUMMARIZE, "relative residual = %e\n", sqrt(xmyNorm(tmp2, tmp1) / B_norm[i])); } sprintf(prefix, "MG level %d (%s): ", param.level + 1, param.location == QUDA_CUDA_FIELD_LOCATION ? 
"GPU" : "CPU"); @@ -849,8 +852,8 @@ namespace quda transfer->P(tmp2, x_coarse); param.matResidual(tmp1, tmp2); tmp2 = param.B[i]; - logQuda(QUDA_SUMMARIZE, "Vector %d: norms %e %e ", i, norm2(param.B[i]), norm2(tmp1)); - logQuda(QUDA_SUMMARIZE, "relative residual = %e\n", sqrt(xmyNorm(tmp2, tmp1) / norm2(param.B[i])) ); + logQuda(QUDA_SUMMARIZE, "Vector %d: norms %e %e ", i, B_norm[i], norm2(tmp1)); + logQuda(QUDA_SUMMARIZE, "relative residual = %e\n", sqrt(xmyNorm(tmp2, tmp1) / B_norm[i]) ); } #endif @@ -1088,12 +1091,12 @@ namespace quda // Prolong r_coarse, place result in tmp2 transfer->P(tmp2, r_coarse); - printfQuda("Vector %d: norms v_k = %e P^dag v_k = %e PP^dag v_k = %e\n", i, norm2(param.B[i]), - norm2(r_coarse), norm2(tmp2)); + printfQuda("Vector %d: norms v_k = %e P^dag v_k = %e PP^dag v_k = %e\n", i, B_norm[i], norm2(r_coarse), + norm2(tmp2)); // Compare v_k and PP^dag v_k. auto max_deviation = blas::max_deviation(tmp2, param.B[i]); - auto l2_deviation = sqrt(xmyNorm(param.B[i], tmp2) / norm2(param.B[i])); + auto l2_deviation = sqrt(xmyNorm(param.B[i], tmp2) / B_norm[i]); printfQuda("L2 relative deviation = %e max deviation = %e\n", l2_deviation, max_deviation[0]); if (param.mg_global.run_oblique_proj_check) { @@ -1111,11 +1114,10 @@ namespace quda transfer->P(tmp2, x_coarse); (*param.matResidual)(tmp1, tmp2); - logQuda(QUDA_SUMMARIZE, "Vector %d: norms v_k %e DP(P^dagDP)P^dag v_k %e\n", i, norm2(param.B[i]), - norm2(tmp1)); + logQuda(QUDA_SUMMARIZE, "Vector %d: norms v_k %e DP(P^dagDP)P^dag v_k %e\n", i, B_norm[i], norm2(tmp1)); max_deviation = blas::max_deviation(tmp1, param.B[i]); logQuda(QUDA_SUMMARIZE, "L2 relative deviation = %e, max deviation = %e\n", - sqrt(xmyNorm(param.B[i], tmp1) / norm2(param.B[i])), max_deviation[0]); + sqrt(xmyNorm(param.B[i], tmp1) / B_norm[i]), max_deviation[0]); } sprintf(prefix, "MG level %d (%s): ", param.level + 1, From 6256391d6fdb0910375050315cc9f2d69247c322 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Mon, 
1 Jul 2024 11:03:14 -0700 Subject: [PATCH 017/103] Fix bug in MRE solver --- lib/inv_mre.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/inv_mre.cpp b/lib/inv_mre.cpp index 2ef194ba4e..f2d225a70b 100644 --- a/lib/inv_mre.cpp +++ b/lib/inv_mre.cpp @@ -128,7 +128,7 @@ namespace quda // compute the residual only if we're going to print it ColorSpinorField r(b); for (auto &a : alpha) a = -a; - blas::caxpy(alpha, q, r); + blas::block::caxpy(alpha, q, r); printfQuda("MinResExt: N = %d, |res| / |src| = %e\n", N, sqrt(blas::norm2(r) / blas::norm2(b))); } From 7962dc3afe4523572356e38a94676702e00e7e44 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Mon, 1 Jul 2024 11:16:58 -0700 Subject: [PATCH 018/103] Apply MRHS optimization to MRE solver --- lib/interface_quda.cpp | 2 +- lib/inv_mre.cpp | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/lib/interface_quda.cpp b/lib/interface_quda.cpp index 370e99f5ae..914593fe2d 100644 --- a/lib/interface_quda.cpp +++ b/lib/interface_quda.cpp @@ -4040,7 +4040,7 @@ void invertMultiShiftQuda(void **hp_x, void *hp_b, QudaInvertParam *param) { CG cg(*m, *mSloppy, *mSloppy, *mSloppy, solverParam); - if (i==0) + if (i == 0) cg(x[i], b, p[i], r2_old[i]); else cg(x[i], b); diff --git a/lib/inv_mre.cpp b/lib/inv_mre.cpp index f2d225a70b..f8d078a677 100644 --- a/lib/inv_mre.cpp +++ b/lib/inv_mre.cpp @@ -106,8 +106,7 @@ namespace quda } // if operator hasn't already been applied then apply - if (apply_mat) - for (int i = 0; i < N; i++) mat(q[i], p[i]); + if (apply_mat) mat(q, p); // Solution coefficient vectors std::vector alpha(N); From 075cfb89cb168a727ddc0b66684f83911c86aa69 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Tue, 2 Jul 2024 23:04:23 -0700 Subject: [PATCH 019/103] Remove complex.h inclusion --- lib/interface_quda.cpp | 1 - lib/targets/cuda/blas_lapack_cublas.cpp | 1 - 2 files changed, 2 deletions(-) diff --git a/lib/interface_quda.cpp b/lib/interface_quda.cpp index 
914593fe2d..2911c1786e 100644 --- a/lib/interface_quda.cpp +++ b/lib/interface_quda.cpp @@ -4,7 +4,6 @@ #include #include #include -#include #include #include diff --git a/lib/targets/cuda/blas_lapack_cublas.cpp b/lib/targets/cuda/blas_lapack_cublas.cpp index 024d7f2fcb..9db0266f92 100644 --- a/lib/targets/cuda/blas_lapack_cublas.cpp +++ b/lib/targets/cuda/blas_lapack_cublas.cpp @@ -1,4 +1,3 @@ -#include #include #include #ifdef NATIVE_LAPACK_LIB From 7cbab278ccc532e3e46e9e27c342ba45fdbed89f Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Wed, 3 Jul 2024 17:58:34 -0700 Subject: [PATCH 020/103] Vectorize all remaining Dirac prepare/reconstruct functions --- lib/dirac_clover.cpp | 53 ++++++------- lib/dirac_domain_wall.cpp | 31 +++----- lib/dirac_domain_wall_4d.cpp | 53 ++++++------- lib/dirac_improved_staggered.cpp | 53 ++++++------- lib/dirac_mobius.cpp | 120 +++++++++++++---------------- lib/dirac_staggered.cpp | 61 +++++++-------- lib/dirac_twisted_clover.cpp | 50 ++++++------ lib/dirac_twisted_mass.cpp | 126 ++++++++++++++----------------- lib/dirac_wilson.cpp | 1 + 9 files changed, 238 insertions(+), 310 deletions(-) diff --git a/lib/dirac_clover.cpp b/lib/dirac_clover.cpp index a7f8702e79..8fdf896552 100644 --- a/lib/dirac_clover.cpp +++ b/lib/dirac_clover.cpp @@ -82,10 +82,8 @@ namespace quda { errorQuda("Preconditioned solution requires a preconditioned solve_type"); } - for (auto i = 0u; i < b.size(); i++) { - src[i] = const_cast(b[i]).create_alias(); - sol[i] = x[i].create_alias(); - } + create_alias(src, b); + create_alias(sol, x); } void DiracClover::reconstruct(cvector_ref &, cvector_ref &, @@ -225,30 +223,25 @@ namespace quda { { // we desire solution to preconditioned system if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) { - for (auto i = 0u; i < b.size(); i++) { - src[i] = const_cast(b[i]).create_alias(); - sol[i] = x[i].create_alias(); - } + create_alias(src, b); + create_alias(sol, x); return; } + 
create_alias(src, x(other_parity)); + create_alias(sol, x(this_parity)); + // we desire solution to full system - auto tmp = getFieldTmp(b[0].Even()); - for (auto i = 0u; i < b.size(); i++) { - if (symmetric) { - // src = A_ee^-1 (b_e + k D_eo A_oo^-1 b_o) - src[i] = x[i][other_parity].create_alias(); - CloverInv(src[i], b[i][other_parity], other_parity); - DiracWilson::DslashXpay(tmp, src[i], this_parity, b[i][this_parity], kappa); - CloverInv(src[i], tmp, this_parity); - sol[i] = x[i][this_parity].create_alias(); - } else { - // src = b_e + k D_eo A_oo^-1 b_o - src[i] = x[i][other_parity].create_alias(); - CloverInv(tmp, b[i][other_parity], other_parity); // safe even when tmp = b.odd - DiracWilson::DslashXpay(src[i], tmp, this_parity, b[this_parity], kappa); - sol[i] = x[i][this_parity].create_alias(); - } + auto tmp = getFieldTmp(x.Even()); + if (symmetric) { + // src = A_ee^-1 (b_e + k D_eo A_oo^-1 b_o) + CloverInv(src, b(other_parity), other_parity); + DiracWilson::DslashXpay(tmp, src, this_parity, b(this_parity), kappa); + CloverInv(src, tmp, this_parity); + } else { + // src = b_e + k D_eo A_oo^-1 b_o + CloverInv(tmp, b(other_parity), other_parity); // safe even when tmp = b.odd + DiracWilson::DslashXpay(src, tmp, this_parity, b(this_parity), kappa); } } @@ -257,13 +250,11 @@ namespace quda { { if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) return; - auto tmp = getFieldTmp(b[0].Even()); - for (auto i = 0u; i < b.size(); i++) { - checkFullSpinor(x[i], b[i]); - // x_o = A_oo^-1 (b_o + k D_oe x_e) - DiracWilson::DslashXpay(tmp, x[i][this_parity], other_parity, b[i][other_parity], kappa); - CloverInv(x[i][other_parity], tmp, other_parity); - } + auto tmp = getFieldTmp(x.Even()); + checkFullSpinor(x, b); + // x_o = A_oo^-1 (b_o + k D_oe x_e) + DiracWilson::DslashXpay(tmp, x(this_parity), other_parity, b(other_parity), kappa); + CloverInv(x(other_parity), tmp, other_parity); } void DiracCloverPC::createCoarseOp(GaugeField &Y, 
GaugeField &X, const Transfer &T, double kappa, double, double mu, diff --git a/lib/dirac_domain_wall.cpp b/lib/dirac_domain_wall.cpp index a7ad1581b4..6e5886116d 100644 --- a/lib/dirac_domain_wall.cpp +++ b/lib/dirac_domain_wall.cpp @@ -84,10 +84,8 @@ namespace quda { errorQuda("Preconditioned solution requires a preconditioned solve_type"); } - for (auto i = 0u; i < b.size(); i++) { - src[i] = const_cast(b[i]).create_alias(); - sol[i] = x[i].create_alias(); - } + create_alias(src, b); + create_alias(sol, x); } void DiracDomainWall::reconstruct(cvector_ref &, cvector_ref &, @@ -152,20 +150,17 @@ namespace quda { const QudaSolutionType solType) const { if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) { - for (auto i = 0u; i < b.size(); i++) { - src[i] = const_cast(b[i]).create_alias(); - sol[i] = x[i].create_alias(); - } + create_alias(src, b); + create_alias(sol, x); return; } + create_alias(src, x(other_parity)); + create_alias(sol, x(this_parity)); + // we desire solution to full system - for (auto i = 0u; i < b.size(); i++) { - // src = b_e + k D_eo b_o - DslashXpay(x[i][other_parity], b[i][other_parity], this_parity, b[this_parity], kappa5); - src[i] = x[i][other_parity].create_alias(); - sol[i] = x[i][this_parity].create_alias(); - } + // src = b_e + k D_eo b_o + DslashXpay(x(other_parity), b(other_parity), this_parity, b(this_parity), kappa5); } void DiracDomainWallPC::reconstruct(cvector_ref &x, cvector_ref &b, @@ -174,11 +169,9 @@ namespace quda { if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) return; // create full solution - for (auto i = 0u; i < b.size(); i++) { - checkFullSpinor(x[i], b[i]); - // x_o = b_o + k D_oe x_e - DslashXpay(x[i][other_parity], x[i][this_parity], other_parity, b[i][other_parity], kappa5); - } + checkFullSpinor(x, b); + // x_o = b_o + k D_oe x_e + DslashXpay(x(other_parity), x(this_parity), other_parity, b(other_parity), kappa5); } } // namespace quda diff --git 
a/lib/dirac_domain_wall_4d.cpp b/lib/dirac_domain_wall_4d.cpp index bfa5e980e4..fe5b4e943e 100644 --- a/lib/dirac_domain_wall_4d.cpp +++ b/lib/dirac_domain_wall_4d.cpp @@ -86,10 +86,8 @@ namespace quda { errorQuda("Preconditioned solution requires a preconditioned solve_type"); } - for (auto i = 0u; i < b.size(); i++) { - src[i] = const_cast(b[i]).create_alias(); - sol[i] = x[i].create_alias(); - } + create_alias(src, b); + create_alias(sol, x); } void DiracDomainWall4D::reconstruct(cvector_ref &, cvector_ref &, @@ -171,30 +169,25 @@ namespace quda { const QudaSolutionType solType) const { if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) { - for (auto i = 0u; i < b.size(); i++) { - src[i] = const_cast(b[i]).create_alias(); - sol[i] = x[i].create_alias(); - } + create_alias(src, b); + create_alias(sol, x); return; } + create_alias(src, x(other_parity)); + create_alias(sol, x(this_parity)); + // we desire solution to full system - auto tmp = getFieldTmp(x[0].Even()); - for (auto i = 0u; i < b.size(); i++) { - if (symmetric) { - // src = M5^-1 (b_e + k D4_eo*M5^-1 b_o) - src[i] = x[i][other_parity].create_alias(); - M5inv(src[i], b[i][other_parity]); - Dslash4Xpay(tmp, src[i], this_parity, b[i][this_parity], kappa5); - M5inv(src[i], tmp); - sol[i] = x[i][this_parity].create_alias(); - } else { - // src = b_e + k D4_eo*M5^-1 b_o - src[i] = x[i][other_parity].create_alias(); - M5inv(tmp, b[i][other_parity]); - Dslash4Xpay(src[i], tmp, this_parity, b[i][this_parity], kappa5); - sol[i] = x[i][this_parity].create_alias(); - } + auto tmp = getFieldTmp(x.Even()); + if (symmetric) { + // src = M5^-1 (b_e + k D4_eo*M5^-1 b_o) + M5inv(src, b(other_parity)); + Dslash4Xpay(tmp, src, this_parity, b(this_parity), kappa5); + M5inv(src, tmp); + } else { + // src = b_e + k D4_eo*M5^-1 b_o + M5inv(tmp, b(other_parity)); + Dslash4Xpay(src, tmp, this_parity, b(this_parity), kappa5); } } @@ -204,13 +197,11 @@ namespace quda { if (solType == 
QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) return; // create full solution - auto tmp = getFieldTmp(x[0].Even()); - for (auto i = 0u; i < b.size(); i++) { - checkFullSpinor(x[i], b[i]); - // x_o = M5^-1 (b_o + k D4_oe x_e) - Dslash4Xpay(tmp, x[i][this_parity], other_parity, b[i][other_parity], kappa5); - M5inv(x[i][other_parity], tmp); - } + auto tmp = getFieldTmp(x.Even()); + checkFullSpinor(x, b); + // x_o = M5^-1 (b_o + k D4_oe x_e) + Dslash4Xpay(tmp, x(this_parity), other_parity, b(other_parity), kappa5); + M5inv(x(other_parity), tmp); } } // end namespace quda diff --git a/lib/dirac_improved_staggered.cpp b/lib/dirac_improved_staggered.cpp index b5e2b0d326..82f6537fe2 100644 --- a/lib/dirac_improved_staggered.cpp +++ b/lib/dirac_improved_staggered.cpp @@ -93,10 +93,8 @@ namespace quda { errorQuda("Preconditioned solution requires a preconditioned solve_type"); } - for (auto i = 0u; i < b.size(); i++) { - src[i] = const_cast(b[i]).create_alias(); - sol[i] = x[i].create_alias(); - } + create_alias(src, b); + create_alias(sol, x); } void DiracImprovedStaggered::reconstruct(cvector_ref &, cvector_ref &, @@ -212,18 +210,17 @@ namespace quda { return; } - for (auto i = 0u; i < b.size(); i++) { - // we desire solution to full system. - // With the convention given in DiracStaggered::M(), - // the source is src = 2m b_e + D_eo b_o - // But remember, DslashXpay actually applies - // -D_eo. Flip the sign on 2m to compensate, and - // then flip the overall sign. - src[i] = x[i][other_parity].create_alias(); - DslashXpay(src[i], b[i][other_parity], this_parity, b[i][this_parity], -2.0 * mass); - blas::ax(-1.0, src[i]); - sol[i] = x[i][this_parity].create_alias(); - } + create_alias(src, x(other_parity)); + create_alias(sol, x(this_parity)); + + // we desire solution to full system. + // With the convention given in DiracStaggered::M(), + // the source is src = 2m b_e + D_eo b_o + // But remember, DslashXpay actually applies + // -D_eo. 
Flip the sign on 2m to compensate, and + // then flip the overall sign. + DslashXpay(src, b(other_parity), this_parity, b(this_parity), -2.0 * mass); + blas::ax(-1.0, src); } void DiracImprovedStaggeredPC::reconstruct(cvector_ref &x, cvector_ref &b, @@ -233,19 +230,17 @@ namespace quda { return; } - for (auto i = 0u; i < b.size(); i++) { - checkFullSpinor(x[i], b[i]); - - // create full solution - // With the convention given in DiracStaggered::M(), - // the reconstruct is x_o = 1/(2m) (b_o + D_oe x_e) - // But remember: DslashXpay actually applies -D_oe, - // so just like above we need to flip the sign - // on b_o. We then correct this by applying an additional - // minus sign when we rescale by 2m. - DslashXpay(x[i][other_parity], x[i][this_parity], other_parity, b[i][other_parity], -1.0); - blas::ax(-0.5 / mass, x[i][other_parity]); - } + checkFullSpinor(x, b); + + // create full solution + // With the convention given in DiracStaggered::M(), + // the reconstruct is x_o = 1/(2m) (b_o + D_oe x_e) + // But remember: DslashXpay actually applies -D_oe, + // so just like above we need to flip the sign + // on b_o. We then correct this by applying an additional + // minus sign when we rescale by 2m. 
+ DslashXpay(x(other_parity), x(this_parity), other_parity, b(other_parity), -1.0); + blas::ax(-0.5 / mass, x(other_parity)); } void DiracImprovedStaggeredPC::createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, double, double mass, diff --git a/lib/dirac_mobius.cpp b/lib/dirac_mobius.cpp index 7e59dded58..c5567f435c 100644 --- a/lib/dirac_mobius.cpp +++ b/lib/dirac_mobius.cpp @@ -136,10 +136,8 @@ namespace quda { errorQuda("Preconditioned solution requires a preconditioned solve_type"); } - for (auto i = 0u; i < b.size(); i++) { - src[i] = const_cast(b[i]).create_alias(); - sol[i] = x[i].create_alias(); - } + create_alias(src, b); + create_alias(sol, x); } void DiracMobius::reconstruct(cvector_ref &, cvector_ref &, @@ -364,49 +362,42 @@ namespace quda { const QudaSolutionType solType) const { if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) { - for (auto i = 0u; i < b.size(); i++) { - src[i] = const_cast(b[i]).create_alias(); - sol[i] = x[i].create_alias(); - } + create_alias(src, b); + create_alias(sol, x); return; } + create_alias(src, x(other_parity)); + create_alias(sol, x(this_parity)); + // we desire solution to full system - auto tmp = getFieldTmp(x[0].Even()); - for (auto i = 0u; i < b.size(); i++) { - if (symmetric) { - // src = D5^-1 (b_e + k D4_eo * D4pre * D5^-1 b_o) - src[i] = x[i][other_parity].create_alias(); - M5inv(tmp, b[i][other_parity]); - Dslash4pre(src[i], tmp); - Dslash4Xpay(tmp, src[i], this_parity, b[i][this_parity], 1.0); - M5inv(src[i], tmp); - sol[i] = x[i][this_parity].create_alias(); - } else { - // src = b_e + k D4_eo * D4pre * D5inv b_o - src[i] = x[i][other_parity].create_alias(); - M5inv(src[i], b[i][other_parity]); - Dslash4pre(tmp, src[i]); - Dslash4Xpay(src[i], tmp, this_parity, b[i][this_parity], 1.0); - sol[i] = x[i][this_parity].create_alias(); - } + auto tmp = getFieldTmp(x.Even()); + if (symmetric) { + // src = D5^-1 (b_e + k D4_eo * D4pre * D5^-1 b_o) + M5inv(tmp, b(other_parity)); 
+ Dslash4pre(src, tmp); + Dslash4Xpay(tmp, src, this_parity, b(this_parity), 1.0); + M5inv(src, tmp); + } else { + // src = b_e + k D4_eo * D4pre * D5inv b_o + M5inv(src, b(other_parity)); + Dslash4pre(tmp, src); + Dslash4Xpay(src, tmp, this_parity, b(this_parity), 1.0); } } void DiracMobiusPC::reconstruct(cvector_ref &x, cvector_ref &b, const QudaSolutionType solType) const { - if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) { return; } + if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) return; // create full solution - auto tmp = getFieldTmp(x[0].Even()); - for (auto i = 0u; i < b.size(); i++) { - checkFullSpinor(x[i], b[i]); - // psi_o = M5^-1 (b_o + k_b D4_oe D4pre x_e) - Dslash4pre(x[i][other_parity], x[i][this_parity]); - Dslash4Xpay(tmp, x[i][other_parity], other_parity, b[i][other_parity], 1.0); - M5inv(x[i][other_parity], tmp); - } + auto tmp = getFieldTmp(x.Even()); + checkFullSpinor(x, b); + // psi_o = M5^-1 (b_o + k_b D4_oe D4pre x_e) + Dslash4pre(x(other_parity), x(this_parity)); + Dslash4Xpay(tmp, x(other_parity), other_parity, b(other_parity), 1.0); + M5inv(x(other_parity), tmp); } void DiracMobiusPC::MdagMLocal(cvector_ref &out, cvector_ref &in) const @@ -592,10 +583,8 @@ namespace quda { errorQuda("Preconditioned solution requires a preconditioned solve_type"); } - for (auto i = 0u; i < b.size(); i++) { - src[i] = const_cast(b[i]).create_alias(); - sol[i] = x[i].create_alias(); - } + create_alias(src, b); + create_alias(sol, x); } void DiracMobiusEofa::reconstruct(cvector_ref &, cvector_ref &, @@ -676,32 +665,27 @@ namespace quda { const QudaSolutionType solType) const { if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) { - for (auto i = 0u; i < b.size(); i++) { - src[i] = const_cast(b[i]).create_alias(); - sol[i] = x[i].create_alias(); - } + create_alias(src, b); + create_alias(sol, x); return; } + create_alias(src, x(other_parity)); + create_alias(sol, 
x(this_parity)); + // we desire solution to full system - auto tmp = getFieldTmp(x[0].Even()); - for (auto i = 0u; i < b.size(); i++) { - if (symmetric) { - // src = D5^-1 (b_e + k D4_eo * D4pre * D5^-1 b_o) - src[i] = x[i][other_parity].create_alias(); - m5inv_eofa(tmp, b[i][other_parity]); - Dslash4pre(src[i], tmp); - Dslash4Xpay(tmp, src[i], this_parity, b[i][this_parity], 1.0); - m5inv_eofa(src[i], tmp); - sol[i] = x[i][this_parity].create_alias(); - } else if (matpcType == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC) { - // src = b_e + k D4_eo * D4pre * D5inv b_o - src[i] = x[i][other_parity].create_alias(); - m5inv_eofa(src[i], b[i][other_parity]); - Dslash4pre(tmp, src[i]); - Dslash4Xpay(src[i], tmp, this_parity, b[i][this_parity], 1.0); - sol[i] = x[i][this_parity].create_alias(); - } + auto tmp = getFieldTmp(x.Even()); + if (symmetric) { + // src = D5^-1 (b_e + k D4_eo * D4pre * D5^-1 b_o) + m5inv_eofa(tmp, b(other_parity)); + Dslash4pre(src, tmp); + Dslash4Xpay(tmp, src, this_parity, b(this_parity), 1.0); + m5inv_eofa(src, tmp); + } else if (matpcType == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC) { + // src = b_e + k D4_eo * D4pre * D5inv b_o + m5inv_eofa(src, b(other_parity)); + Dslash4pre(tmp, src); + Dslash4Xpay(src, tmp, this_parity, b(this_parity), 1.0); } } @@ -711,14 +695,12 @@ namespace quda { if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) return; // create full solution - auto tmp = getFieldTmp(x[0].Even()); - for (auto i = 0u; i < b.size(); i++) { - checkFullSpinor(x[i], b[i]); - // psi_o = M5^-1 (b_o + k_b D4_oe D4pre x_e) - Dslash4pre(x[i][other_parity], x[i][this_parity]); - Dslash4Xpay(tmp, x[i][other_parity], other_parity, b[i][other_parity], 1.0); - m5inv_eofa(x[i][other_parity], tmp); - } + auto tmp = getFieldTmp(x.Even()); + checkFullSpinor(x, b); + // psi_o = M5^-1 (b_o + k_b D4_oe D4pre x_e) + Dslash4pre(x(other_parity), x(this_parity)); + Dslash4Xpay(tmp, x(other_parity), other_parity, b(other_parity), 1.0); + 
m5inv_eofa(x(other_parity), tmp); } void DiracMobiusEofaPC::MdagM(cvector_ref &out, cvector_ref &in) const diff --git a/lib/dirac_staggered.cpp b/lib/dirac_staggered.cpp index c850733ce6..7a034bf3b8 100644 --- a/lib/dirac_staggered.cpp +++ b/lib/dirac_staggered.cpp @@ -88,10 +88,8 @@ namespace quda { errorQuda("Preconditioned solution requires a preconditioned solve_type"); } - for (auto i = 0u; i < b.size(); i++) { - src[i] = const_cast(b[i]).create_alias(); - sol[i] = x[i].create_alias(); - } + create_alias(src, b); + create_alias(sol, x); } void DiracStaggered::reconstruct(cvector_ref &, cvector_ref &, @@ -208,26 +206,23 @@ namespace quda { const QudaSolutionType solType) const { if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) { - for (auto i = 0u; i < b.size(); i++) { - // we desire solution to preconditioned system - src[i] = const_cast(b[i]).create_alias(); - sol[i] = x[i].create_alias(); - } + // we desire solution to preconditioned system + create_alias(src, b); + create_alias(sol, x); return; } - for (auto i = 0u; i < b.size(); i++) { - // we desire solution to full system. - // With the convention given in DiracStaggered::M(), - // the source is src = 2m b_e + D_eo b_o - // But remember, DslashXpay actually applies - // -D_eo. Flip the sign on 2m to compensate, and - // then flip the overall sign. - src[i] = x[i][other_parity].create_alias(); - DslashXpay(src[i], b[i][other_parity], this_parity, b[i][this_parity], -2.0 * mass); - blas::ax(-1.0, src[i]); - sol[i] = x[i][this_parity].create_alias(); - } + create_alias(src, x(other_parity)); + create_alias(sol, x(this_parity)); + + // we desire solution to full system. + // With the convention given in DiracStaggered::M(), + // the source is src = 2m b_e + D_eo b_o + // But remember, DslashXpay actually applies + // -D_eo. Flip the sign on 2m to compensate, and + // then flip the overall sign. 
+ DslashXpay(src, b(other_parity), this_parity, b(this_parity), -2.0 * mass); + blas::ax(-1.0, src); } void DiracStaggeredPC::reconstruct(cvector_ref &x, cvector_ref &b, @@ -237,19 +232,17 @@ namespace quda { return; } - for (auto i = 0u; i < b.size(); i++) { - checkFullSpinor(x[i], b[i]); - - // create full solution - // With the convention given in DiracStaggered::M(), - // the reconstruct is x_o = 1/(2m) (b_o + D_oe x_e) - // But remember: DslashXpay actually applies -D_oe, - // so just like above we need to flip the sign - // on b_o. We then correct this by applying an additional - // minus sign when we rescale by 2m. - DslashXpay(x[i][other_parity], x[i][this_parity], other_parity, b[i][other_parity], -1.0); - blas::ax(-0.5 / mass, x[i][other_parity]); - } + checkFullSpinor(x, b); + + // create full solution + // With the convention given in DiracStaggered::M(), + // the reconstruct is x_o = 1/(2m) (b_o + D_oe x_e) + // But remember: DslashXpay actually applies -D_oe, + // so just like above we need to flip the sign + // on b_o. We then correct this by applying an additional + // minus sign when we rescale by 2m. 
+ DslashXpay(x(other_parity), x(this_parity), other_parity, b(other_parity), -1.0); + blas::ax(-0.5 / mass, x(other_parity)); } void DiracStaggeredPC::createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, double, double mass, double, diff --git a/lib/dirac_twisted_clover.cpp b/lib/dirac_twisted_clover.cpp index cb6b46bdb1..ca630da80c 100644 --- a/lib/dirac_twisted_clover.cpp +++ b/lib/dirac_twisted_clover.cpp @@ -125,10 +125,8 @@ namespace quda { errorQuda("Preconditioned solution requires a preconditioned solve_type"); } - for (auto i = 0u; i < b.size(); i++) { - src[i] = const_cast(b[i]).create_alias(); - sol[i] = x[i].create_alias(); - } + create_alias(src, b); + create_alias(sol, x); } void DiracTwistedClover::reconstruct(cvector_ref &, cvector_ref &, @@ -295,31 +293,27 @@ namespace quda { { // we desire solution to preconditioned system if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) { - for (auto i = 0u; i < b.size(); i++) { - src[i] = const_cast(b[i]).create_alias(); - sol[i] = x[i].create_alias(); - } + create_alias(src, b); + create_alias(sol, x); return; } // we desire solution to full system - auto tmp = getFieldTmp(x[0].Even()); - for (auto i = 0u; i < b.size(); i++) { - src[i] = x[i][other_parity].create_alias(); - sol[i] = x[i][this_parity].create_alias(); - - TwistCloverInv(!symmetric ? static_cast(tmp) : src[i], b[i][other_parity], other_parity); + auto tmp = getFieldTmp(x.Even()); + create_alias(src, x(other_parity)); + create_alias(sol, x(this_parity)); - if (symmetric) { - // src = A_ee^-1 (b_e + k D_eo A_oo^-1 b_o) - WilsonDslashXpay(tmp, src[i], this_parity, b[i][this_parity], kappa); - } else { - // src = b_e + k D_eo A_oo^-1 b_o - WilsonDslashXpay(src[i], tmp, this_parity, b[i][this_parity], kappa); - } + TwistCloverInv(!symmetric ? 
tmp : src, b(other_parity), other_parity); - if (symmetric) TwistCloverInv(src[i], tmp, this_parity); + if (symmetric) { + // src = A_ee^-1 (b_e + k D_eo A_oo^-1 b_o) + WilsonDslashXpay(tmp, src, this_parity, b(this_parity), kappa); + } else { + // src = b_e + k D_eo A_oo^-1 b_o + WilsonDslashXpay(src, tmp, this_parity, b(this_parity), kappa); } + + if (symmetric) TwistCloverInv(src, tmp, this_parity); } void DiracTwistedCloverPC::reconstruct(cvector_ref &x, cvector_ref &b, @@ -327,13 +321,11 @@ namespace quda { { if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) return; - auto tmp = getFieldTmp(x[0].Even()); - for (auto i = 0u; i < b.size(); i++) { - checkFullSpinor(x[i], b[i]); - // x_o = A_oo^-1 (b_o + k D_oe x_e) - WilsonDslashXpay(tmp, x[i][this_parity], other_parity, b[i][other_parity], kappa); - TwistCloverInv(x[i][other_parity], tmp, other_parity); - } + auto tmp = getFieldTmp(x.Even()); + checkFullSpinor(x, b); + // x_o = A_oo^-1 (b_o + k D_oe x_e) + WilsonDslashXpay(tmp, x(this_parity), other_parity, b(other_parity), kappa); + TwistCloverInv(x(other_parity), tmp, other_parity); } void DiracTwistedCloverPC::createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, double kappa, double, diff --git a/lib/dirac_twisted_mass.cpp b/lib/dirac_twisted_mass.cpp index a6b13706fb..c7909145ce 100644 --- a/lib/dirac_twisted_mass.cpp +++ b/lib/dirac_twisted_mass.cpp @@ -108,10 +108,8 @@ namespace quda { errorQuda("Preconditioned solution requires a preconditioned solve_type"); } - for (auto i = 0u; i < b.size(); i++) { - src[i] = const_cast(b[i]).create_alias(); - sol[i] = x[i].create_alias(); - } + create_alias(src, b); + create_alias(sol, x); } void DiracTwistedMass::reconstruct(cvector_ref &, cvector_ref &, @@ -165,14 +163,14 @@ namespace quda { if (in.TwistFlavor() != out.TwistFlavor()) errorQuda("Twist flavors %d %d don't match", in.TwistFlavor(), out.TwistFlavor()); if (in.TwistFlavor() == QUDA_TWIST_NO || in.TwistFlavor() == 
QUDA_TWIST_INVALID) - errorQuda("Twist flavor not set %d\n", in.TwistFlavor()); + errorQuda("Twist flavor not set %d", in.TwistFlavor()); if (in.TwistFlavor() == QUDA_TWIST_SINGLET) { double a = -2.0 * kappa * mu; // for inverse twist double b = 1.0 / (1.0 + a * a); bool asymmetric - = (matpcType == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC || matpcType == QUDA_MATPC_ODD_ODD_ASYMMETRIC) && dagger; + = (matpcType == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC || matpcType == QUDA_MATPC_ODD_ODD_ASYMMETRIC) && dagger; ApplyTwistedMassPreconditioned(out, in, *gauge, b, a, false, in, parity, dagger, asymmetric, commDim.data, profile); } else {//TWIST doublet : double a = 2.0 * kappa * mu; @@ -180,7 +178,7 @@ namespace quda { double c = 1.0 / (1.0 + a * a - b * b); bool asymmetric - = (matpcType == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC || matpcType == QUDA_MATPC_ODD_ODD_ASYMMETRIC) && dagger; + = (matpcType == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC || matpcType == QUDA_MATPC_ODD_ODD_ASYMMETRIC) && dagger; ApplyNdegTwistedMassPreconditioned(out, in, *gauge, c, -2.0 * mu * kappa, 2.0 * kappa * epsilon, false, in, parity, dagger, asymmetric, commDim.data, profile); } @@ -200,9 +198,9 @@ namespace quda { if(in.TwistFlavor() == QUDA_TWIST_SINGLET) { double a = -2.0 * kappa * mu; // for inverse twist double b = k / (1.0 + a * a); - // asymmetric should never be true here since we never need to apply 1 + k * A^{-1} D^\dagger + // asymmetric should never be false here since we never need to apply 1 + k * A^{-1} D^\dagger bool asymmetric - = (matpcType == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC || matpcType == QUDA_MATPC_ODD_ODD_ASYMMETRIC) && dagger; + = (matpcType == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC || matpcType == QUDA_MATPC_ODD_ODD_ASYMMETRIC) && dagger; ApplyTwistedMassPreconditioned(out, in, *gauge, b, a, true, x, parity, dagger, asymmetric, commDim.data, profile); } else {//TWIST_DOUBLET: double a = 2.0 * kappa * mu; @@ -210,7 +208,7 @@ namespace quda { double c = 1.0 / (1.0 + a * a - b * b); bool asymmetric - = 
(matpcType == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC || matpcType == QUDA_MATPC_ODD_ODD_ASYMMETRIC) && dagger; + = (matpcType == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC || matpcType == QUDA_MATPC_ODD_ODD_ASYMMETRIC) && dagger; ApplyNdegTwistedMassPreconditioned(out, in, *gauge, k * c, -2 * mu * kappa, 2 * kappa * epsilon, true, x, parity, dagger, asymmetric, commDim.data, profile); } @@ -244,56 +242,50 @@ namespace quda { { // we desire solution to preconditioned system if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) { - for (auto i = 0u; i < b.size(); i++) { - src[i] = const_cast(b[i]).create_alias(); - sol[i] = x[i].create_alias(); - } + create_alias(src, b); + create_alias(sol, x); return; } // we desire solution to full system - auto tmp = getFieldTmp(x[0].Even()); - bool symmetric = (matpcType == QUDA_MATPC_EVEN_EVEN || matpcType == QUDA_MATPC_ODD_ODD) ? true : false; + auto tmp = getFieldTmp(x.Even()); + create_alias(src, x(other_parity)); + create_alias(sol, x(this_parity)); - for (auto i = 0u; i < b.size(); i++) { - src[i] = x[i][other_parity].create_alias(); - sol[i] = x[i][this_parity].create_alias(); + TwistInv(symmetric ? src : tmp, b(other_parity)); - TwistInv(symmetric ? 
src[i] : static_cast(tmp), b[i][other_parity]); + if (b.TwistFlavor() == QUDA_TWIST_SINGLET) { - if (b.TwistFlavor() == QUDA_TWIST_SINGLET) { - - if (symmetric) { - // src = A_ee^-1 (b_e + k D_eo A_oo^-1 b_o) - DiracWilson::DslashXpay(tmp, src[i], this_parity, b[i][this_parity], kappa); - } else { - // src = b_e + k D_eo A_oo^-1 b_o - DiracWilson::DslashXpay(src[i], tmp, this_parity, b[i][this_parity], kappa); - } + if (symmetric) { + // src = A_ee^-1 (b_e + k D_eo A_oo^-1 b_o) + DiracWilson::DslashXpay(tmp, src, this_parity, b(this_parity), kappa); + } else { + // src = b_e + k D_eo A_oo^-1 b_o + DiracWilson::DslashXpay(src, tmp, this_parity, b(this_parity), kappa); + } - } else { // doublet: + } else { // doublet: - // repurpose the preconditioned dslash as a vectorized operator: 1+kappa*D - double mu_ = mu; - mu = 0.0; - double epsilon_ = epsilon; - epsilon = 0.0; + // repurpose the preconditioned dslash as a vectorized operator: 1+kappa*D + double mu_ = mu; + mu = 0.0; + double epsilon_ = epsilon; + epsilon = 0.0; - if (symmetric) { - // src = A_ee^-1(b_e + k D_eo A_oo^-1 b_o) - DslashXpay(tmp, src[i], this_parity, b[i][this_parity], kappa); - } else { - // src = b_e + k D_eo A_oo^-1 b_o - DslashXpay(src[i], tmp, this_parity, b[i][this_parity], kappa); - } + if (symmetric) { + // src = A_ee^-1(b_e + k D_eo A_oo^-1 b_o) + DslashXpay(tmp, src, this_parity, b(this_parity), kappa); + } else { + // src = b_e + k D_eo A_oo^-1 b_o + DslashXpay(src, tmp, this_parity, b(this_parity), kappa); + } - mu = mu_; - epsilon = epsilon_; + mu = mu_; + epsilon = epsilon_; - } // end of doublet + } // end of doublet - if (symmetric) TwistInv(src[i], tmp); - } + if (symmetric) TwistInv(src, tmp); } void DiracTwistedMassPC::reconstruct(cvector_ref &x, cvector_ref &b, @@ -301,29 +293,27 @@ namespace quda { { if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) return; - auto tmp = getFieldTmp(x[0].Even()); - for (auto i = 0u; i < b.size(); i++) { - 
checkFullSpinor(x[i], b[i]); - - // create full solution - if (b.TwistFlavor() == QUDA_TWIST_SINGLET) { - // x_o = A_oo^-1 (b_o + k D_oe x_e) - DiracWilson::DslashXpay(tmp, x[i][this_parity], other_parity, b[i][other_parity], kappa); - } else { // twist doublet: - double mu_ = mu; - mu = 0.0; - double epsilon_ = epsilon; - epsilon = 0.0; - - // x_o = A_oo^-1 (b_o + k D_oe x_e) - DslashXpay(tmp, x[i][this_parity], other_parity, b[i][other_parity], kappa); - - mu = mu_; - epsilon = epsilon_; - } + auto tmp = getFieldTmp(x.Even()); + checkFullSpinor(x, b); + + // create full solution + if (b.TwistFlavor() == QUDA_TWIST_SINGLET) { + // x_o = A_oo^-1 (b_o + k D_oe x_e) + DiracWilson::DslashXpay(tmp, x(this_parity), other_parity, b(other_parity), kappa); + } else { // twist doublet: + double mu_ = mu; + mu = 0.0; + double epsilon_ = epsilon; + epsilon = 0.0; - TwistInv(x[i][other_parity], tmp); + // x_o = A_oo^-1 (b_o + k D_oe x_e) + DslashXpay(tmp, x(this_parity), other_parity, b(other_parity), kappa); + + mu = mu_; + epsilon = epsilon_; } + + TwistInv(x(other_parity), tmp); } void DiracTwistedMassPC::createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, double kappa, double, diff --git a/lib/dirac_wilson.cpp b/lib/dirac_wilson.cpp index a6a5264b8e..c9570e57ee 100644 --- a/lib/dirac_wilson.cpp +++ b/lib/dirac_wilson.cpp @@ -137,6 +137,7 @@ namespace quda { const QudaSolutionType solType) const { if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) { + // we desire solution to preconditioned system create_alias(src, b); create_alias(sol, x); return; From d4886076f463d5ee9d63d07101070d2679dd4e22 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Wed, 3 Jul 2024 18:00:43 -0700 Subject: [PATCH 021/103] Fix bug in GammaApply with introduced in #1416 --- lib/dslash_gamma_helper.cu | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/lib/dslash_gamma_helper.cu b/lib/dslash_gamma_helper.cu index 2824776915..9d53ffbd32 100644 
--- a/lib/dslash_gamma_helper.cu +++ b/lib/dslash_gamma_helper.cu @@ -32,7 +32,6 @@ namespace quda { void preTune() { out.backup(); } void postTune() { out.restore(); } - long long flops() const { return 0; } long long bytes() const { return out.Bytes() + in.Bytes(); } }; @@ -86,12 +85,11 @@ namespace quda { void apply(const qudaStream_t &stream) { TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity()); - launch(tp, stream, GammaArg(out, in, d, kappa, mu, epsilon, dagger, type)); + launch(tp, stream, GammaArg(out, in, d, 0, kappa, mu, epsilon, dagger, type)); } void preTune() { out.backup(); } void postTune() { out.restore(); } - long long flops() const { return 0; } long long bytes() const { return out.Bytes() + in.Bytes(); } }; From 6d1bafed5f9fe912060d2711011f9a5222fd1aec Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Wed, 3 Jul 2024 18:01:11 -0700 Subject: [PATCH 022/103] Fix issue with CG::hq_solve --- lib/inv_cg_quda.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/inv_cg_quda.cpp b/lib/inv_cg_quda.cpp index 31911b16cd..70f44f3f39 100644 --- a/lib/inv_cg_quda.cpp +++ b/lib/inv_cg_quda.cpp @@ -678,7 +678,7 @@ namespace quda { auto get_hq_res = [](cvector_ref &x, cvector_ref &r) { auto hq_nrm = blas::HeavyQuarkResidualNorm(x, r); vector hq_res(hq_nrm.size()); - for (auto i = 0u; i < hq_nrm.size(); i++) hq_res[i] = sqrt(hq_res[i]); + for (auto i = 0u; i < hq_nrm.size(); i++) hq_res[i] = sqrt(hq_nrm[i].z); return hq_res; }; @@ -824,7 +824,7 @@ namespace quda { cvector_ref &r) { auto hq_nrm = blas::xpyHeavyQuarkResidualNorm(x, y, r); vector hq_res(hq_nrm.size()); - for (auto i = 0u; i < hq_nrm.size(); i++) hq_res[i] = sqrt(hq_res[i]); + for (auto i = 0u; i < hq_nrm.size(); i++) hq_res[i] = sqrt(hq_nrm[i].z); return hq_res; }; From 902d8ab47c0574f8f55e5c2a31afe74fff2f11bc Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Tue, 9 Jul 2024 16:34:50 -0700 Subject: [PATCH 023/103] Fix bug with Clover Hasenbsusch operator (wrong 
braces) --- lib/dirac_clover_hasenbusch_twist.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lib/dirac_clover_hasenbusch_twist.cpp b/lib/dirac_clover_hasenbusch_twist.cpp index fe08ac765a..079f2ec5b9 100644 --- a/lib/dirac_clover_hasenbusch_twist.cpp +++ b/lib/dirac_clover_hasenbusch_twist.cpp @@ -29,14 +29,14 @@ namespace quda void DiracCloverHasenbuschTwist::M(cvector_ref &out, cvector_ref &in) const { if (symmetric) { - ApplyWilsonCloverHasenbuschTwist(out[this_parity], in[other_parity], *gauge, *clover, -kappa, mu, in[this_parity], + ApplyWilsonCloverHasenbuschTwist(out(this_parity), in(other_parity), *gauge, *clover, -kappa, mu, in(this_parity), this_parity, dagger, commDim.data, profile); - ApplyWilsonClover(out[other_parity], in[this_parity], *gauge, *clover, -kappa, in[other_parity], other_parity, + ApplyWilsonClover(out(other_parity), in(this_parity), *gauge, *clover, -kappa, in(other_parity), other_parity, dagger, commDim.data, profile); } else { - ApplyWilsonClover(out[other_parity], in[this_parity], *gauge, *clover, -kappa, in[other_parity], other_parity, + ApplyWilsonClover(out(other_parity), in(this_parity), *gauge, *clover, -kappa, in(other_parity), other_parity, dagger, commDim.data, profile); - ApplyTwistedClover(out[this_parity], in[other_parity], *gauge, *clover, -kappa, mu, in[this_parity], this_parity, + ApplyTwistedClover(out(this_parity), in(other_parity), *gauge, *clover, -kappa, mu, in(this_parity), this_parity, dagger, commDim.data, profile); } } From 02eecaa9eaa869b353935d22b58059c7e204320e Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Wed, 10 Jul 2024 23:18:01 -0700 Subject: [PATCH 024/103] Fix bug with DiracCoarsePC::reconstruct when using odd solve --- lib/dirac_coarse.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/lib/dirac_coarse.cpp b/lib/dirac_coarse.cpp index 0136b29ac3..55e827b0ca 100644 --- a/lib/dirac_coarse.cpp +++ b/lib/dirac_coarse.cpp @@ -513,11 +513,11 
@@ namespace quda { QudaFieldLocation location = checkLocation(out[0], in[0]); initializeLazy(location); - if ( location == QUDA_CUDA_FIELD_LOCATION) { + if (location == QUDA_CUDA_FIELD_LOCATION) { auto Y = apply_mma(out, dslash_use_mma) ? Yhat_aos_d : Yhat_d; auto X = apply_mma(out, dslash_use_mma) ? X_aos_d : X_d; ApplyCoarse(out, in, in, *Y, *X, kappa, parity, true, false, dagger, commDim.data, halo_precision, dslash_use_mma); - } else if ( location == QUDA_CPU_FIELD_LOCATION ) { + } else if (location == QUDA_CPU_FIELD_LOCATION) { ApplyCoarse(out, in, in, *Yhat_h, *X_h, kappa, parity, true, false, dagger, commDim.data, halo_precision, dslash_use_mma); } @@ -618,9 +618,9 @@ namespace quda { auto tmp = getFieldTmp(x.Even()); #if 1 // x_o = A_oo^-1 (b_o - D_oe x_e) - DiracCoarse::Dslash(tmp, x.Even(), QUDA_ODD_PARITY); - blas::xpay(b.Odd(), -1.0, tmp); - CloverInv(x.Odd(), tmp, QUDA_ODD_PARITY); + DiracCoarse::Dslash(tmp, x(this_parity), other_parity); + blas::xpay(b(other_parity), -1.0, tmp); + CloverInv(x(other_parity), tmp, other_parity); #else // x_o = A_oo^{-1} b_o - (A_oo^{-1} D_oe) x_e Dslash(tmp, x(this_parity), other_parity); From 8b067d4efe530d8d3520ca6de2461faed9147cf5 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Wed, 10 Jul 2024 23:25:45 -0700 Subject: [PATCH 025/103] Fix bug with counting bytes with clover operator --- lib/dslash_clover_helper.cu | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lib/dslash_clover_helper.cu b/lib/dslash_clover_helper.cu index 0754191850..d8923a7044 100644 --- a/lib/dslash_clover_helper.cu +++ b/lib/dslash_clover_helper.cu @@ -38,7 +38,7 @@ namespace quda { long long flops() const { return in.size() * in.Volume() * 504ll; } - long long bytes() const { return in.size() * (out.Bytes() + in.Bytes() + clover.Bytes() / (3 - in.SiteSubset())); } + long long bytes() const { return out.Bytes() + in.Bytes() + clover.Bytes() / (3 - in.SiteSubset()); } }; //Apply the clover matrix field to a 
colorspinor field @@ -123,10 +123,10 @@ namespace quda { long long flops() const { return in.size() * (inverse ? 1056ll : 552ll) * in.Volume(); } long long bytes() const { - long long rtn = out.Bytes() + in.Bytes() + clover.Bytes() / (3 - in.SiteSubset()); + long long rtn = out.Bytes() + in.Bytes() + in.size() * clover.Bytes() / (3 - in.SiteSubset()); if (twist == QUDA_TWIST_GAMMA5_INVERSE && !clover::dynamic_inverse()) - rtn += clover.Bytes() / (3 - in.SiteSubset()); - return in.size() * rtn; + rtn += in.size() * clover.Bytes() / (3 - in.SiteSubset()); + return rtn; } }; From c682cae9288d15ccbf375f8ff74a9a2893cb286c Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Wed, 10 Jul 2024 23:46:26 -0700 Subject: [PATCH 026/103] Default inner GCR solver to use L2 residual to enable early exit if possible --- lib/solver.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/lib/solver.cpp b/lib/solver.cpp index ee7278e041..b4f84da600 100644 --- a/lib/solver.cpp +++ b/lib/solver.cpp @@ -215,9 +215,8 @@ namespace quda { errorQuda("Unexpected preconditioned solver %d", outer.inv_type); } - // this sets a fixed iteration count if we're using the MR solver - inner.residual_type - = (outer.inv_type_precondition == QUDA_MR_INVERTER) ? 
QUDA_INVALID_RESIDUAL : QUDA_L2_RELATIVE_RESIDUAL; + // allows the inner solver to early exit if it converges quickly + inner.residual_type = QUDA_L2_RELATIVE_RESIDUAL; inner.iter = 0; inner.inv_type_precondition = QUDA_INVALID_INVERTER; From 1c5baefd3974c4cfe96bc097e531d6f3b7d281a4 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Sun, 21 Jul 2024 02:49:36 -0700 Subject: [PATCH 027/103] Initial work to prepare for multi-rhs solver exposure: move the body of invertQuda to new function solve which is MRHS aware --- lib/CMakeLists.txt | 2 +- lib/interface_quda.cpp | 328 +++-------------------------------------- lib/multigrid.cpp | 1 - lib/solve.cpp | 315 +++++++++++++++++++++++++++++++++++++++ lib/timer.cpp | 3 + 5 files changed, 338 insertions(+), 311 deletions(-) create mode 100644 lib/solve.cpp diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index 4d8639b9c9..0b08bb37f4 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -17,7 +17,7 @@ endif() set (QUDA_OBJS # cmake-format: sortable - monitor.cpp dirac_coarse.cpp dslash_coarse.cpp + solve.cpp monitor.cpp dirac_coarse.cpp dslash_coarse.cpp coarse_op.cpp coarsecoarse_op.cpp coarse_op_preconditioned.cpp staggered_coarse_op.cpp eig_iram.cpp eig_trlm.cpp eig_block_trlm.cpp vector_io.cpp diff --git a/lib/interface_quda.cpp b/lib/interface_quda.cpp index 8de817fa21..1f47aff52e 100644 --- a/lib/interface_quda.cpp +++ b/lib/interface_quda.cpp @@ -297,6 +297,13 @@ static void profilerStop(const char *f) { namespace quda { void printLaunchTimer(); + + void massRescale(cvector_ref &b, QudaInvertParam ¶m, bool for_multishift); + + void distanceReweight(cvector_ref &b, QudaInvertParam ¶m, bool inverse); + + void solve(cvector_ref &x, cvector_ref &b, Dirac &dirac, Dirac &diracSloppy, + Dirac &diracPre, Dirac &diracEig, QudaInvertParam ¶m); } void setVerbosityQuda(QudaVerbosity verbosity, const char prefix[], FILE *outfile) @@ -1703,109 +1710,6 @@ namespace quda { dEig = Dirac::create(diracEigParam); } - void 
massRescale(ColorSpinorField &b, QudaInvertParam ¶m, bool for_multishift) - { - double kappa5 = (0.5/(5.0 + param.m5)); - double kappa = (param.dslash_type == QUDA_DOMAIN_WALL_DSLASH || param.dslash_type == QUDA_DOMAIN_WALL_4D_DSLASH - || param.dslash_type == QUDA_MOBIUS_DWF_DSLASH || param.dslash_type == QUDA_MOBIUS_DWF_EOFA_DSLASH) ? - kappa5 : - param.kappa; - - logQuda(QUDA_DEBUG_VERBOSE, "Mass rescale: Kappa is: %g\n", kappa); - logQuda(QUDA_DEBUG_VERBOSE, "Mass rescale: mass normalization: %d\n", param.mass_normalization); - logQuda(QUDA_DEBUG_VERBOSE, "Mass rescale: norm of source in = %g\n", blas::norm2(b)); - - // staggered dslash uses mass normalization internally - if (param.dslash_type == QUDA_ASQTAD_DSLASH || param.dslash_type == QUDA_STAGGERED_DSLASH) { - switch (param.solution_type) { - case QUDA_MAT_SOLUTION: - case QUDA_MATPC_SOLUTION: - if (param.mass_normalization == QUDA_KAPPA_NORMALIZATION) blas::ax(2.0*param.mass, b); - break; - case QUDA_MATDAG_MAT_SOLUTION: - case QUDA_MATPCDAG_MATPC_SOLUTION: - if (param.mass_normalization == QUDA_KAPPA_NORMALIZATION) blas::ax(4.0*param.mass*param.mass, b); - break; - default: - errorQuda("Not implemented"); - } - return; - } - - // multiply the source to compensate for normalization of the Dirac operator, if necessary - // you are responsible for restoring what's in param.offset - switch (param.solution_type) { - case QUDA_MAT_SOLUTION: - if (param.mass_normalization == QUDA_MASS_NORMALIZATION || - param.mass_normalization == QUDA_ASYMMETRIC_MASS_NORMALIZATION) { - blas::ax(2.0*kappa, b); - if (for_multishift) - for (int i = 0; i < param.num_offset; i++) param.offset[i] *= 2.0 * kappa; - } - break; - case QUDA_MATDAG_MAT_SOLUTION: - if (param.mass_normalization == QUDA_MASS_NORMALIZATION || - param.mass_normalization == QUDA_ASYMMETRIC_MASS_NORMALIZATION) { - blas::ax(4.0*kappa*kappa, b); - if (for_multishift) - for (int i = 0; i < param.num_offset; i++) param.offset[i] *= 4.0 * kappa * kappa; - } - break; 
- case QUDA_MATPC_SOLUTION: - if (param.mass_normalization == QUDA_MASS_NORMALIZATION) { - blas::ax(4.0*kappa*kappa, b); - if (for_multishift) - for (int i = 0; i < param.num_offset; i++) param.offset[i] *= 4.0 * kappa * kappa; - } else if (param.mass_normalization == QUDA_ASYMMETRIC_MASS_NORMALIZATION) { - blas::ax(2.0*kappa, b); - if (for_multishift) - for (int i = 0; i < param.num_offset; i++) param.offset[i] *= 2.0 * kappa; - } - break; - case QUDA_MATPCDAG_MATPC_SOLUTION: - if (param.mass_normalization == QUDA_MASS_NORMALIZATION) { - blas::ax(16.0*std::pow(kappa,4), b); - if (for_multishift) - for (int i = 0; i < param.num_offset; i++) param.offset[i] *= 16.0 * std::pow(kappa, 4); - } else if (param.mass_normalization == QUDA_ASYMMETRIC_MASS_NORMALIZATION) { - blas::ax(4.0*kappa*kappa, b); - if (for_multishift) - for (int i = 0; i < param.num_offset; i++) param.offset[i] *= 4.0 * kappa * kappa; - } - break; - default: - errorQuda("Solution type %d not supported", param.solution_type); - } - - logQuda(QUDA_DEBUG_VERBOSE, "Mass rescale: norm of source out = %g\n", blas::norm2(b)); - } -} - -void distanceReweight(ColorSpinorField &b, QudaInvertParam ¶m, bool inverse) -{ - // Force the alpha0 to be positive. - // A negative alpha0 matches something like Eq.(12) in arXiv:1006.4028. - // Disable the negative situation as QUDA already has multigrid for light quarks. 
- const double alpha0 = abs(param.distance_pc_alpha0); - const int t0 = param.distance_pc_t0; - if (alpha0 != 0.0 && t0 >= 0) { - if (param.dslash_type != QUDA_WILSON_DSLASH && param.dslash_type != QUDA_CLOVER_WILSON_DSLASH) { - errorQuda("Only Wilson and Wilson-clover dslash support distance preconditioning, but get dslash_type %d\n", - param.dslash_type); - } - if (param.inv_type == QUDA_MG_INVERTER) { - errorQuda("Multigrid solver doesn't support distance preconditioning\n"); - } - if (param.cuda_prec != QUDA_DOUBLE_PRECISION || param.cuda_prec_sloppy != QUDA_DOUBLE_PRECISION) { - warningQuda( - "Using single or half (sloppy) precision in distance preconditioning sometimes makes the solver diverge"); - } - - if (inverse) - spinorDistanceReweight(b, -alpha0, t0); - else - spinorDistanceReweight(b, alpha0, t0); - } } void dslashQuda(void *h_out, void *h_in, QudaInvertParam *inv_param, QudaParity parity) @@ -1840,7 +1744,7 @@ void dslashQuda(void *h_out, void *h_in, QudaInvertParam *inv_param, QudaParity in = in_h; - profileDslash.TPSTART(QUDA_PROFILE_COMPUTE); + getProfile().TPSTART(QUDA_PROFILE_COMPUTE); logQuda(QUDA_DEBUG_VERBOSE, "In CPU %e CUDA %e\n", blas::norm2(in_h), blas::norm2(in)); @@ -1872,7 +1776,7 @@ void dslashQuda(void *h_out, void *h_in, QudaInvertParam *inv_param, QudaParity } else { dirac->Dslash(out, in, parity); // apply the operator } - profileDslash.TPSTOP(QUDA_PROFILE_COMPUTE); + getProfile().TPSTOP(QUDA_PROFILE_COMPUTE); distanceReweight(out, *inv_param, false); @@ -2796,9 +2700,6 @@ void eigensolveQuda(void **host_evecs, double _Complex *host_evals, QudaEigParam delete dEig; popVerbosity(); - - // cache is written out even if a long benchmarking job gets interrupted - saveTuneCache(); } multigrid_solver::multigrid_solver(QudaMultigridParam &mg_param) @@ -2876,9 +2777,6 @@ multigrid_solver::multigrid_solver(QudaMultigridParam &mg_param) mg = new MG(*mgParam); mgParam->updateInvertParam(*param); - - // cache is written out even if a long 
benchmarking job gets interrupted - saveTuneCache(); } void *newMultigridQuda(QudaMultigridParam *mg_param) @@ -2889,8 +2787,6 @@ void *newMultigridQuda(QudaMultigridParam *mg_param) auto *mg = new multigrid_solver(*mg_param); - saveTuneCache(); - popVerbosity(); profilerStop(__func__); return static_cast(mg); @@ -3003,9 +2899,6 @@ void updateMultigridQuda(void *mg_, QudaMultigridParam *mg_param) setOutputPrefix(""); - // cache is written out even if a long benchmarking job gets interrupted - saveTuneCache(); - profileInvert.TPSTOP(QUDA_PROFILE_PREAMBLE); popVerbosity(); @@ -3124,28 +3017,17 @@ void invertQuda(void *hp_x, void *hp_b, QudaInvertParam *param) (param->solution_type == QUDA_MATPCDAG_MATPC_SOLUTION); bool pc_solve = (param->solve_type == QUDA_DIRECT_PC_SOLVE) || (param->solve_type == QUDA_NORMOP_PC_SOLVE) || (param->solve_type == QUDA_NORMERR_PC_SOLVE); - bool mat_solution = (param->solution_type == QUDA_MAT_SOLUTION) || - (param->solution_type == QUDA_MATPC_SOLUTION); - bool direct_solve = (param->solve_type == QUDA_DIRECT_SOLVE) || - (param->solve_type == QUDA_DIRECT_PC_SOLVE); - bool norm_error_solve = (param->solve_type == QUDA_NORMERR_SOLVE) || - (param->solve_type == QUDA_NORMERR_PC_SOLVE); param->iter = 0; - Dirac *d = nullptr; - Dirac *dSloppy = nullptr; - Dirac *dPre = nullptr; - Dirac *dEig = nullptr; + Dirac *dirac = nullptr; + Dirac *diracSloppy = nullptr; + Dirac *diracPre = nullptr; + Dirac *diracEig = nullptr; // Create the dirac operator and operators for sloppy, precondition, // and an eigensolver - createDiracWithEig(d, dSloppy, dPre, dEig, *param, pc_solve); - - Dirac &dirac = *d; - Dirac &diracSloppy = *dSloppy; - Dirac &diracPre = *dPre; - Dirac &diracEig = *dEig; + createDiracWithEig(dirac, diracSloppy, diracPre, diracEig, *param, pc_solve); // wrap CPU host side pointers ColorSpinorParam cpuParam(hp_b, *param, cudaGauge->X(), pc_solution, param->input_location); @@ -3193,181 +3075,16 @@ void invertQuda(void *hp_x, void *hp_b, 
QudaInvertParam *param) blas::zero(x); } - // if we're doing a managed memory MG solve and prefetching is - // enabled, prefetch all the Dirac matrices. There's probably - // a better place to put this... - if (param->inv_type_precondition == QUDA_MG_INVERTER) { - dirac.prefetch(QUDA_CUDA_FIELD_LOCATION); - diracSloppy.prefetch(QUDA_CUDA_FIELD_LOCATION); - diracPre.prefetch(QUDA_CUDA_FIELD_LOCATION); - } - - profileInvert.TPSTART(QUDA_PROFILE_PREAMBLE); - - double nb = blas::norm2(b); - if (nb==0.0) errorQuda("Source has zero norm"); - logQuda(QUDA_VERBOSE, "Source: %g\n", nb); - if (param->use_init_guess == QUDA_USE_INIT_GUESS_YES) logQuda(QUDA_VERBOSE, "Initial guess: %g\n", blas::norm2(x)); - - // rescale the source and solution vectors to help prevent the onset of underflow - if (param->solver_normalization == QUDA_SOURCE_NORMALIZATION) { - blas::ax(1.0 / sqrt(nb), b); - blas::ax(1.0 / sqrt(nb), x); - } - - massRescale(b, *param, false); - distanceReweight(b, *param, true); - - ColorSpinorField in; - ColorSpinorField out; - dirac.prepare(out, in, x, b, param->solution_type); - - logQuda(QUDA_VERBOSE, "Prepared source = %g\n", blas::norm2(in)); - logQuda(QUDA_VERBOSE, "Prepared solution = %g\n", blas::norm2(out)); - - // solution_type specifies *what* system is to be solved. - // solve_type specifies *how* the system is to be solved. - // - // We have the following four cases (plus preconditioned variants): - // - // solution_type solve_type Effect - // ------------- ---------- ------ - // MAT DIRECT Solve Ax=b - // MATDAG_MAT DIRECT Solve A^dag y = b, followed by Ax=y - // MAT NORMOP Solve (A^dag A) x = (A^dag b) - // MATDAG_MAT NORMOP Solve (A^dag A) x = b - // MAT NORMERR Solve (A A^dag) y = b, then x = A^dag y - // - // We generally require that the solution_type and solve_type - // preconditioning match. As an exception, the unpreconditioned MAT - // solution_type may be used with any solve_type, including - // DIRECT_PC and NORMOP_PC. 
In these cases, preparation of the - // preconditioned source and reconstruction of the full solution are - // taken care of by Dirac::prepare() and Dirac::reconstruct(), - // respectively. - - profileInvert.TPSTOP(QUDA_PROFILE_PREAMBLE); - - if (mat_solution && !direct_solve && !norm_error_solve) { // prepare source: b' = A^dag b - ColorSpinorField tmp(in); - dirac.Mdag(in, tmp); - } else if (!mat_solution && direct_solve) { // perform the first of two solves: A^dag y = b - DiracMdag m(dirac), mSloppy(diracSloppy), mPre(diracPre), mEig(diracEig); - SolverParam solverParam(*param); - Solver *solve = Solver::create(solverParam, m, mSloppy, mPre, mEig); - (*solve)(out, in); - blas::copy(in, out); - delete solve; - solverParam.updateInvertParam(*param); - } - - if (direct_solve) { - DiracM m(dirac), mSloppy(diracSloppy), mPre(diracPre), mEig(diracEig); - SolverParam solverParam(*param); - - // chronological forecasting - if (param->chrono_use_resident && chronoResident[param->chrono_index].size() > 0) { - bool hermitian = false; - auto &mChrono = param->chrono_precision == param->cuda_prec ? m : mSloppy; - chronoExtrapolate(out, in, chronoResident[param->chrono_index], mChrono, hermitian); - } - - Solver *solve = Solver::create(solverParam, m, mSloppy, mPre, mEig); - (*solve)(out, in); - delete solve; - solverParam.updateInvertParam(*param); - } else if (!norm_error_solve) { - DiracMdagM m(dirac), mSloppy(diracSloppy), mPre(diracPre), mEig(diracEig); - SolverParam solverParam(*param); - - // chronological forecasting - if (param->chrono_use_resident && chronoResident[param->chrono_index].size() > 0) { - bool hermitian = true; - auto &mChrono = param->chrono_precision == param->cuda_prec ? 
m : mSloppy; - chronoExtrapolate(out, in, chronoResident[param->chrono_index], mChrono, hermitian); - } - - // if using a Schwarz preconditioner with a normal operator then we must use the DiracMdagMLocal operator - if (param->inv_type_precondition != QUDA_INVALID_INVERTER && param->schwarz_type != QUDA_INVALID_SCHWARZ) { - DiracMdagMLocal mPreLocal(diracPre); - Solver *solve = Solver::create(solverParam, m, mSloppy, mPreLocal, mEig); - (*solve)(out, in); - delete solve; - solverParam.updateInvertParam(*param); - } else { - Solver *solve = Solver::create(solverParam, m, mSloppy, mPre, mEig); - (*solve)(out, in); - delete solve; - solverParam.updateInvertParam(*param); - } - } else { // norm_error_solve - DiracMMdag m(dirac), mSloppy(diracSloppy), mPre(diracPre), mEig(diracEig); - ColorSpinorField tmp(out); - SolverParam solverParam(*param); - Solver *solve = Solver::create(solverParam, m, mSloppy, mPre, mEig); - (*solve)(tmp, in); // y = (M M^\dag) b - dirac.Mdag(out, tmp); // x = M^dag y - delete solve; - solverParam.updateInvertParam(*param); - } - - logQuda(QUDA_VERBOSE, "Solution = %g\n", blas::norm2(x)); - - profileInvert.TPSTART(QUDA_PROFILE_EPILOGUE); - if (param->chrono_make_resident) { - const int i = param->chrono_index; - if (i >= QUDA_MAX_CHRONO) - errorQuda("Requested chrono index %d is outside of max %d\n", i, QUDA_MAX_CHRONO); - - auto &basis = chronoResident[i]; - - if (param->chrono_max_dim < (int)basis.size()) { - errorQuda("Requested chrono_max_dim %i is smaller than already existing chronology %lu", param->chrono_max_dim, basis.size()); - } - - if(not param->chrono_replace_last){ - // if we have not filled the space yet just augment - if ((int)basis.size() < param->chrono_max_dim) { - ColorSpinorParam cs_param(out); - cs_param.setPrecision(param->chrono_precision); - basis.emplace_back(cs_param); - } - - // shuffle every entry down one and bring the last to the front - std::rotate(basis.begin(), basis.end() - 1, basis.end()); - } - basis[0] = 
out; // set first entry to new solution - } - dirac.reconstruct(x, b, param->solution_type); - - distanceReweight(x, *param, false); - - if (param->solver_normalization == QUDA_SOURCE_NORMALIZATION) { - // rescale the solution - blas::ax(sqrt(nb), x); - } - - if (param->compute_action) { - Complex action = blas::cDotProduct(b, x); - param->action[0] = action.real(); - param->action[1] = action.imag(); - } - - profileInvert.TPSTOP(QUDA_PROFILE_EPILOGUE); + solve(x, b, *dirac, *diracSloppy, *diracPre, *diracEig, *param); if (!param->make_resident_solution) h_x = x; - logQuda(QUDA_VERBOSE, "Reconstructed solution: %g\n", blas::norm2(x)); - if (param->use_resident_solution && !param->make_resident_solution) solutionResident.clear(); - delete d; - delete dSloppy; - delete dPre; - delete dEig; - - // cache is written out even if a long benchmarking job gets interrupted - saveTuneCache(); + delete dirac; + delete diracSloppy; + delete diracPre; + delete diracEig; profilerStop(__func__); popVerbosity(); @@ -4084,9 +3801,6 @@ void invertMultiShiftQuda(void **hp_x, void *hp_b, QudaInvertParam *param) delete dPre; delete dRefine; - // cache is written out even if a long benchmarking job gets interrupted - saveTuneCache(); - profilerStop(__func__); popVerbosity(); } @@ -5468,8 +5182,6 @@ void performTwoLinkGaussianSmearNStep(void *h_in, QudaQuarkSmearParam *smear_par delete d; if (smear_param->delete_2link != 0) { freeUniqueGaugeQuda(QUDA_SMEARED_LINKS); } - - saveTuneCache(); } void performGaugeSmearQuda(QudaGaugeSmearParam *smear_param, QudaGaugeObservableParam *obs_param) @@ -5743,8 +5455,6 @@ void contractFTQuda(void **prop_array_flavor_1, void **prop_array_flavor_2, void } } profileContractFT.TPSTOP(QUDA_PROFILE_COMPUTE); - - saveTuneCache(); } void contractQuda(const void *hp_x, const void *hp_y, void *h_result, const QudaContractType cType, diff --git a/lib/multigrid.cpp b/lib/multigrid.cpp index b07a007238..9a239e37b2 100644 --- a/lib/multigrid.cpp +++ 
b/lib/multigrid.cpp @@ -1691,7 +1691,6 @@ namespace quda } logQuda(QUDA_VERBOSE, "Done building free vectors\n"); - popLevel(); } diff --git a/lib/solve.cpp b/lib/solve.cpp new file mode 100644 index 0000000000..d3f03464b6 --- /dev/null +++ b/lib/solve.cpp @@ -0,0 +1,315 @@ +#include "invert_quda.h" + +namespace quda { + + // vector of spinors used for forecasting solutions in HMC +#define QUDA_MAX_CHRONO 12 + // each entry is one p + std::vector> chronoResident(QUDA_MAX_CHRONO); + + void massRescale(cvector_ref &b, QudaInvertParam &param, bool for_multishift) + { + double kappa5 = (0.5/(5.0 + param.m5)); + double kappa = (param.dslash_type == QUDA_DOMAIN_WALL_DSLASH || param.dslash_type == QUDA_DOMAIN_WALL_4D_DSLASH + || param.dslash_type == QUDA_MOBIUS_DWF_DSLASH || param.dslash_type == QUDA_MOBIUS_DWF_EOFA_DSLASH) ? + kappa5 : + param.kappa; + + logQuda(QUDA_DEBUG_VERBOSE, "Mass rescale: Kappa is: %g\n", kappa); + logQuda(QUDA_DEBUG_VERBOSE, "Mass rescale: mass normalization: %d\n", param.mass_normalization); + if (getVerbosity() >= QUDA_DEBUG_VERBOSE) { + auto b2 = blas::norm2(b); + for (auto &b2i : b2) printfQuda("Mass rescale: norm of source in = %g\n", b2i); + } + + // staggered dslash uses mass normalization internally + if (param.dslash_type == QUDA_ASQTAD_DSLASH || param.dslash_type == QUDA_STAGGERED_DSLASH) { + switch (param.solution_type) { + case QUDA_MAT_SOLUTION: + case QUDA_MATPC_SOLUTION: + if (param.mass_normalization == QUDA_KAPPA_NORMALIZATION) blas::ax(2.0*param.mass, b); + break; + case QUDA_MATDAG_MAT_SOLUTION: + case QUDA_MATPCDAG_MATPC_SOLUTION: + if (param.mass_normalization == QUDA_KAPPA_NORMALIZATION) blas::ax(4.0*param.mass*param.mass, b); + break; + default: + errorQuda("Not implemented"); + } + return; + } + + // multiply the source to compensate for normalization of the Dirac operator, if necessary + // you are responsible for restoring what's in param.offset + switch (param.solution_type) { + case QUDA_MAT_SOLUTION: + if
(param.mass_normalization == QUDA_MASS_NORMALIZATION || + param.mass_normalization == QUDA_ASYMMETRIC_MASS_NORMALIZATION) { + blas::ax(2.0*kappa, b); + if (for_multishift) + for (int i = 0; i < param.num_offset; i++) param.offset[i] *= 2.0 * kappa; + } + break; + case QUDA_MATDAG_MAT_SOLUTION: + if (param.mass_normalization == QUDA_MASS_NORMALIZATION || + param.mass_normalization == QUDA_ASYMMETRIC_MASS_NORMALIZATION) { + blas::ax(4.0*kappa*kappa, b); + if (for_multishift) + for (int i = 0; i < param.num_offset; i++) param.offset[i] *= 4.0 * kappa * kappa; + } + break; + case QUDA_MATPC_SOLUTION: + if (param.mass_normalization == QUDA_MASS_NORMALIZATION) { + blas::ax(4.0*kappa*kappa, b); + if (for_multishift) + for (int i = 0; i < param.num_offset; i++) param.offset[i] *= 4.0 * kappa * kappa; + } else if (param.mass_normalization == QUDA_ASYMMETRIC_MASS_NORMALIZATION) { + blas::ax(2.0*kappa, b); + if (for_multishift) + for (int i = 0; i < param.num_offset; i++) param.offset[i] *= 2.0 * kappa; + } + break; + case QUDA_MATPCDAG_MATPC_SOLUTION: + if (param.mass_normalization == QUDA_MASS_NORMALIZATION) { + blas::ax(16.0*std::pow(kappa,4), b); + if (for_multishift) + for (int i = 0; i < param.num_offset; i++) param.offset[i] *= 16.0 * std::pow(kappa, 4); + } else if (param.mass_normalization == QUDA_ASYMMETRIC_MASS_NORMALIZATION) { + blas::ax(4.0*kappa*kappa, b); + if (for_multishift) + for (int i = 0; i < param.num_offset; i++) param.offset[i] *= 4.0 * kappa * kappa; + } + break; + default: + errorQuda("Solution type %d not supported", param.solution_type); + } + + if (getVerbosity() >= QUDA_DEBUG_VERBOSE) { + auto b2 = blas::norm2(b); + for (auto &b2i : b2) printfQuda("Mass rescale: norm of source out = %g\n", b2i); + } + } + + void distanceReweight(cvector_ref &b, QudaInvertParam &param, bool inverse) + { + // Force the alpha0 to be positive. + // A negative alpha0 matches something like Eq.(12) in arXiv:1006.4028.
+ // Disable the negative situation as QUDA already has multigrid for light quarks. + const double alpha0 = abs(param.distance_pc_alpha0); + const int t0 = param.distance_pc_t0; + if (alpha0 != 0.0 && t0 >= 0) { + if (param.dslash_type != QUDA_WILSON_DSLASH && param.dslash_type != QUDA_CLOVER_WILSON_DSLASH) { + errorQuda("Only Wilson and Wilson-clover dslash support distance preconditioning, but get dslash_type %d\n", + param.dslash_type); + } + if (param.inv_type == QUDA_MG_INVERTER) errorQuda("Multigrid solver doesn't support distance preconditioning"); + + if (param.cuda_prec != QUDA_DOUBLE_PRECISION || param.cuda_prec_sloppy != QUDA_DOUBLE_PRECISION) { + warningQuda( + "Using single or half (sloppy) precision in distance preconditioning sometimes makes the solver diverge"); + } + + if (inverse) + for (auto i = 0u; i < b.size(); i++) spinorDistanceReweight(b[i], -alpha0, t0); + else + for (auto i = 0u; i < b.size(); i++) spinorDistanceReweight(b[i], alpha0, t0); + } + } + + void solve(cvector_ref &x, cvector_ref &b, + Dirac &dirac, Dirac &diracSloppy, Dirac &diracPre, Dirac &diracEig, + QudaInvertParam ¶m) + { + getProfile().TPSTART(QUDA_PROFILE_PREAMBLE); + + bool mat_solution = (param.solution_type == QUDA_MAT_SOLUTION) || + (param.solution_type == QUDA_MATPC_SOLUTION); + bool direct_solve = (param.solve_type == QUDA_DIRECT_SOLVE) || + (param.solve_type == QUDA_DIRECT_PC_SOLVE); + bool norm_error_solve = (param.solve_type == QUDA_NORMERR_SOLVE) || + (param.solve_type == QUDA_NORMERR_PC_SOLVE); + + auto nb = blas::norm2(b); + for (auto &bi : nb) { + if (bi == 0.0) errorQuda("Source has zero norm"); + logQuda(QUDA_VERBOSE, "Source: %g\n", bi); + } + if (param.use_init_guess == QUDA_USE_INIT_GUESS_YES) { + auto x_norm = blas::norm2(x); + for (auto &xi : x_norm) logQuda(QUDA_VERBOSE, "Initial guess: %g\n", xi); + } + // rescale the source and solution vectors to help prevent the onset of underflow + if (param.solver_normalization == QUDA_SOURCE_NORMALIZATION) { + 
auto nb_inv(nb); + for (auto bi : nb_inv) bi = 1 / sqrt(bi); + blas::ax(nb_inv, b); + blas::ax(nb_inv, x); + } + + massRescale(b, param, false); + distanceReweight(b, param, true); + + std::vector in(b.size()); + std::vector out(b.size()); + + // if we're doing a managed memory MG solve and prefetching is + // enabled, prefetch all the Dirac matrices. There's probably + // a better place to put this... + if (param.inv_type_precondition == QUDA_MG_INVERTER) { + dirac.prefetch(QUDA_CUDA_FIELD_LOCATION); + diracSloppy.prefetch(QUDA_CUDA_FIELD_LOCATION); + diracPre.prefetch(QUDA_CUDA_FIELD_LOCATION); + } + + dirac.prepare(out, in, x, b, param.solution_type); + + if (getVerbosity() >= QUDA_VERBOSE) { + auto in_norm = blas::norm2(in); + auto out_norm = blas::norm2(out); + for (auto i = 0u; i < in.size(); i++) + logQuda(QUDA_VERBOSE, "Prepared: source = %g, solution = %g\n", in_norm[i], out_norm[i]); + } + + // solution_type specifies *what* system is to be solved. + // solve_type specifies *how* the system is to be solved. + // + // We have the following four cases (plus preconditioned variants): + // + // solution_type solve_type Effect + // ------------- ---------- ------ + // MAT DIRECT Solve Ax=b + // MATDAG_MAT DIRECT Solve A^dag y = b, followed by Ax=y + // MAT NORMOP Solve (A^dag A) x = (A^dag b) + // MATDAG_MAT NORMOP Solve (A^dag A) x = b + // MAT NORMERR Solve (A A^dag) y = b, then x = A^dag y + // + // We generally require that the solution_type and solve_type + // preconditioning match. As an exception, the unpreconditioned MAT + // solution_type may be used with any solve_type, including + // DIRECT_PC and NORMOP_PC. In these cases, preparation of the + // preconditioned source and reconstruction of the full solution are + // taken care of by Dirac::prepare() and Dirac::reconstruct(), + // respectively. 
+ + getProfile().TPSTOP(QUDA_PROFILE_PREAMBLE); + + if (mat_solution && !direct_solve && !norm_error_solve) { // prepare source: b' = A^dag b + auto tmp = getFieldTmp(cvector_ref(in)); + blas::copy(tmp, in); + dirac.Mdag(in, tmp); + } else if (!mat_solution && direct_solve) { // perform the first of two solves: A^dag y = b + DiracMdag m(dirac), mSloppy(diracSloppy), mPre(diracPre), mEig(diracEig); + SolverParam solverParam(param); + Solver *solve = Solver::create(solverParam, m, mSloppy, mPre, mEig); + (*solve)(out, in); + blas::copy(in, out); + delete solve; + solverParam.updateInvertParam(param); + } + + if (direct_solve) { + DiracM m(dirac), mSloppy(diracSloppy), mPre(diracPre), mEig(diracEig); + SolverParam solverParam(param); + + // chronological forecasting + if (param.chrono_use_resident && chronoResident[param.chrono_index].size() > 0) { + bool hermitian = false; + auto &mChrono = param.chrono_precision == param.cuda_prec ? m : mSloppy; + chronoExtrapolate(out[0], in[0], chronoResident[param.chrono_index], mChrono, hermitian); + } + + Solver *solve = Solver::create(solverParam, m, mSloppy, mPre, mEig); + (*solve)(out, in); + delete solve; + solverParam.updateInvertParam(param); + } else if (!norm_error_solve) { + DiracMdagM m(dirac), mSloppy(diracSloppy), mPre(diracPre), mEig(diracEig); + SolverParam solverParam(param); + + // chronological forecasting + if (param.chrono_use_resident && chronoResident[param.chrono_index].size() > 0) { + bool hermitian = true; + auto &mChrono = param.chrono_precision == param.cuda_prec ? 
m : mSloppy; + chronoExtrapolate(out[0], in[0], chronoResident[param.chrono_index], mChrono, hermitian); + } + + // if using a Schwarz preconditioner with a normal operator then we must use the DiracMdagMLocal operator + if (param.inv_type_precondition != QUDA_INVALID_INVERTER && param.schwarz_type != QUDA_INVALID_SCHWARZ) { + DiracMdagMLocal mPreLocal(diracPre); + Solver *solve = Solver::create(solverParam, m, mSloppy, mPreLocal, mEig); + (*solve)(out, in); + delete solve; + solverParam.updateInvertParam(param); + } else { + Solver *solve = Solver::create(solverParam, m, mSloppy, mPre, mEig); + (*solve)(out, in); + delete solve; + solverParam.updateInvertParam(param); + } + } else { // norm_error_solve + DiracMMdag m(dirac), mSloppy(diracSloppy), mPre(diracPre), mEig(diracEig); + auto tmp = getFieldTmp(cvector_ref(in)); + SolverParam solverParam(param); + Solver *solve = Solver::create(solverParam, m, mSloppy, mPre, mEig); + (*solve)(tmp, in); // y = (M M^\dag) b + dirac.Mdag(out, tmp); // x = M^dag y + delete solve; + solverParam.updateInvertParam(param); + } + + if (getVerbosity() >= QUDA_VERBOSE) { + auto x_norm = blas::norm2(out); + for (auto i = 0u; i < x.size(); i++) printfQuda("Solution = %g\n", x_norm[i]); + } + + getProfile().TPSTART(QUDA_PROFILE_EPILOGUE); + if (param.chrono_make_resident) { + const int i = param.chrono_index; + if (i >= QUDA_MAX_CHRONO) + errorQuda("Requested chrono index %d is outside of max %d\n", i, QUDA_MAX_CHRONO); + + auto &basis = chronoResident[i]; + + if (param.chrono_max_dim < (int)basis.size()) { + errorQuda("Requested chrono_max_dim %i is smaller than already existing chronology %lu", param.chrono_max_dim, basis.size()); + } + + if(not param.chrono_replace_last){ + // if we have not filled the space yet just augment + if ((int)basis.size() < param.chrono_max_dim) { + ColorSpinorParam cs_param(out[0]); + cs_param.setPrecision(param.chrono_precision); + basis.emplace_back(cs_param); + } + + // shuffle every entry down one and 
bring the last to the front + std::rotate(basis.begin(), basis.end() - 1, basis.end()); + } + basis[0] = out[0]; // set first entry to new solution + } + + dirac.reconstruct(x, b, param.solution_type); + + distanceReweight(x, param, false); + + if (param.solver_normalization == QUDA_SOURCE_NORMALIZATION) { + // rescale the solution + for (auto bi : nb) bi = sqrt(bi); + blas::ax(nb, x); + } + + if (getVerbosity() >= QUDA_VERBOSE) { + auto x_norm = blas::norm2(x); + for (auto i = 0u; i < x.size(); i++) printfQuda("Reconstructed Solution = %g\n", x_norm[i]); + } + + if (param.compute_action) { + auto action = blas::cDotProduct(b, x); + param.action[0] = action[0].real(); + param.action[1] = action[0].imag(); + } + + getProfile().TPSTOP(QUDA_PROFILE_EPILOGUE); + } + +} diff --git a/lib/timer.cpp b/lib/timer.cpp index f8ed6a17c7..6a122da4fd 100644 --- a/lib/timer.cpp +++ b/lib/timer.cpp @@ -263,6 +263,9 @@ namespace quda { secs = profile.Last(QUDA_PROFILE_TOTAL); gflops = (Tunable::flops_global() - flops) * 1e-9; if (&gflops != &gflops_dummy) comm_allreduce_sum(gflops); + + // cache is written out even if a long benchmarking job gets interrupted + saveTuneCache(); } } From faf4658efc66dc2881e615416c8c1a42b2e3185d Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Sun, 21 Jul 2024 02:50:03 -0700 Subject: [PATCH 028/103] Fix flops counters for blas and reduce functions --- lib/blas_quda.cu | 2 +- lib/reduce_quda.cu | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/blas_quda.cu b/lib/blas_quda.cu index c76ae11881..1b5f552bca 100644 --- a/lib/blas_quda.cu +++ b/lib/blas_quda.cu @@ -130,7 +130,7 @@ namespace quda { return location == QUDA_CPU_FIELD_LOCATION ? 
false : Tunable::advanceTuneParam(param); } - long long flops() const override { return f.flops() * x.Length(); } + long long flops() const override { return f.flops() * x.Length() * x.size(); } long long bytes() const override { return (f.read.X + f.write.X) * x.Bytes() + (f.read.Y + f.write.Y) * y.Bytes() + diff --git a/lib/reduce_quda.cu b/lib/reduce_quda.cu index 61790d2ec5..58e2e9b199 100644 --- a/lib/reduce_quda.cu +++ b/lib/reduce_quda.cu @@ -135,7 +135,7 @@ namespace quda { if (r.write.V) v.restore(); } - long long flops() const override { return r.flops() * x.Length(); } + long long flops() const override { return r.flops() * x.Length() * x.size(); } long long bytes() const override { From b2f9849d12e1957ef4fe3ae2fb1ba1a48eeccb0b Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Sun, 21 Jul 2024 03:07:16 -0700 Subject: [PATCH 029/103] Move remainder of invertQuda body into new MRHS solve wrapper that is also now called by the invertMultiSrcQuda interface --- lib/interface_quda.cpp | 122 ++++++---------------- lib/solve.cpp | 231 ++++++++++++++++++++++++++++------------- 2 files changed, 193 insertions(+), 160 deletions(-) diff --git a/lib/interface_quda.cpp b/lib/interface_quda.cpp index 1f47aff52e..75771e1ebb 100644 --- a/lib/interface_quda.cpp +++ b/lib/interface_quda.cpp @@ -108,7 +108,12 @@ CloverField *cloverEigensolver = nullptr; GaugeField momResident; GaugeField *extendedGaugeResident = nullptr; -std::vector solutionResident; +namespace quda +{ + + std::vector solutionResident; + +} // vector of spinors used for forecasting solutions in HMC #define QUDA_MAX_CHRONO 12 @@ -302,8 +307,8 @@ namespace quda { void distanceReweight(cvector_ref &b, QudaInvertParam ¶m, bool inverse); - void solve(cvector_ref &x, cvector_ref &b, Dirac &dirac, Dirac &diracSloppy, - Dirac &diracPre, Dirac &diracEig, QudaInvertParam ¶m); + void solve(const std::vector &hp_x, const std::vector &hp_b, QudaInvertParam ¶m, + const GaugeField &u); } void 
setVerbosityQuda(QudaVerbosity verbosity, const char prefix[], FILE *outfile) @@ -3009,83 +3014,10 @@ void invertQuda(void *hp_x, void *hp_b, QudaInvertParam *param) // check the gauge fields have been created GaugeField *cudaGauge = checkGauge(param); - // It was probably a bad design decision to encode whether the system is even/odd preconditioned (PC) in - // solve_type and solution_type, rather than in separate members of QudaInvertParam. We're stuck with it - // for now, though, so here we factorize everything for convenience. - - bool pc_solution = (param->solution_type == QUDA_MATPC_SOLUTION) || - (param->solution_type == QUDA_MATPCDAG_MATPC_SOLUTION); - bool pc_solve = (param->solve_type == QUDA_DIRECT_PC_SOLVE) || - (param->solve_type == QUDA_NORMOP_PC_SOLVE) || (param->solve_type == QUDA_NORMERR_PC_SOLVE); - - param->iter = 0; - - Dirac *dirac = nullptr; - Dirac *diracSloppy = nullptr; - Dirac *diracPre = nullptr; - Dirac *diracEig = nullptr; - - // Create the dirac operator and operators for sloppy, precondition, - // and an eigensolver - createDiracWithEig(dirac, diracSloppy, diracPre, diracEig, *param, pc_solve); - - // wrap CPU host side pointers - ColorSpinorParam cpuParam(hp_b, *param, cudaGauge->X(), pc_solution, param->input_location); - ColorSpinorField h_b(cpuParam); - - cpuParam.v = hp_x; - cpuParam.location = param->output_location; - ColorSpinorField h_x(cpuParam); - - // download source - ColorSpinorParam cudaParam(cpuParam, *param, QUDA_CUDA_FIELD_LOCATION); - cudaParam.create = QUDA_COPY_FIELD_CREATE; - cudaParam.field = &h_b; - ColorSpinorField b(cudaParam); - - // now check if we need to invalidate the solutionResident vectors - ColorSpinorField x; - if (param->use_resident_solution == 1) { - for (auto &v : solutionResident) { - if (b.Precision() != v.Precision() || b.SiteSubset() != v.SiteSubset()) { - solutionResident.clear(); - break; - } - } - - if (!solutionResident.size()) { - cudaParam.create = QUDA_NULL_FIELD_CREATE; - 
solutionResident = std::vector(1, cudaParam); - } - x = solutionResident[0].create_alias(cudaParam); - } else { - cudaParam.create = QUDA_NULL_FIELD_CREATE; - x = ColorSpinorField(cudaParam); - } - - if (param->use_init_guess == QUDA_USE_INIT_GUESS_YES && !param->chrono_use_resident) { // download initial guess - // initial guess only supported for single-pass solvers - if ((param->solution_type == QUDA_MATDAG_MAT_SOLUTION || param->solution_type == QUDA_MATPCDAG_MATPC_SOLUTION) && - (param->solve_type == QUDA_DIRECT_SOLVE || param->solve_type == QUDA_DIRECT_PC_SOLVE)) { - errorQuda("Initial guess not supported for two-pass solver"); - } - - x = h_x; // solution - } else { // zero initial guess - blas::zero(x); - } - - solve(x, b, *dirac, *diracSloppy, *diracPre, *diracEig, *param); - - if (!param->make_resident_solution) h_x = x; + solve({hp_x}, {hp_b}, *param, *cudaGauge); if (param->use_resident_solution && !param->make_resident_solution) solutionResident.clear(); - delete dirac; - delete diracSloppy; - delete diracPre; - delete diracEig; - profilerStop(__func__); popVerbosity(); } @@ -3165,7 +3097,10 @@ void callMultiSrcQuda(void **_hp_x, void **_hp_b, QudaInvertParam *param, // col if (num_sub_partition == 1) { // In this case we don't split the grid. - for (int n = 0; n < param->num_src; n++) { op(_hp_x[n], _hp_b[n], param, args...); } + std::vector x(param->num_src), b(param->num_src); + for (auto i = 0u; i < x.size(); i++) x[i] = _hp_x[i]; + for (auto i = 0u; i < b.size(); i++) b[i] = _hp_b[i]; + op(x, b, *param, args...); } else { @@ -3381,17 +3316,19 @@ void callMultiSrcQuda(void **_hp_x, void **_hp_b, QudaInvertParam *param, // col // Since input fields are in Native order now param_copy.dirac_order = QUDA_INTERNAL_DIRAC_ORDER; - // We need to set the cpu_prec in the param_copy, because the op() passed in - // to us will try to create wrappers to the pointers we pass in. 
They expect - // the input spinors to be on the host, and will use param_copy.cpu_prec to set - // the precision. We want to avoid the situation, where the internal prec and the - // cpu_prec are somehow different. - param_copy.cpu_prec = _collect_b[0].Precision(); + // We need to set the cpu_prec in the param_copy, because the op() passed in + // to us will try to create wrappers to the pointers we pass in. They expect + // the input spinors to be on the host, and will use param_copy.cpu_prec to set + // the precision. We want to avoid the situation, where the internal prec and the + // cpu_prec are somehow different. + param_copy.cpu_prec = _collect_b[0].Precision(); // Do the solves - for (int n = 0; n < param->num_src_per_sub_partition; n++) { - op(_collect_x[n].data(), _collect_b[n].data(), ¶m_copy, args...); - } + std::vector x_raw(param->num_src_per_sub_partition); + std::vector b_raw(param->num_src_per_sub_partition); + for (auto i = 0u; i < x_raw.size(); i++) x_raw[i] = _collect_x[i].data(); + for (auto i = 0u; i < b_raw.size(); i++) b_raw[i] = _collect_b[i].data(); + op(x_raw, b_raw, param_copy, args...); profileInvertMultiSrc.TPSTART(QUDA_PROFILE_EPILOGUE); push_communicator(default_comm_key); @@ -3457,14 +3394,19 @@ void callMultiSrcQuda(void **_hp_x, void **_hp_b, QudaInvertParam *param, // col void invertMultiSrcQuda(void **_hp_x, void **_hp_b, QudaInvertParam *param) { - auto op = [](void *_x, void *_b, QudaInvertParam *param) { invertQuda(_x, _b, param); }; + auto op = [](const std::vector &_x, const std::vector &_b, QudaInvertParam ¶m) { + // check the gauge fields have been created + GaugeField *gauge = checkGauge(¶m); + solve(_x, _b, param, *gauge); + }; callMultiSrcQuda(_hp_x, _hp_b, param, op); } - void dslashMultiSrcQuda(void **_hp_x, void **_hp_b, QudaInvertParam *param, QudaParity parity) { - auto op = [](void *_x, void *_b, QudaInvertParam *param, QudaParity parity) { dslashQuda(_x, _b, param, parity); }; + auto op = [](const std::vector 
&_x, const std::vector &_b, QudaInvertParam ¶m, QudaParity parity) { + for (auto i = 0u; i < _b.size(); i++) dslashQuda(_x[i], _b[i], ¶m, parity); + }; callMultiSrcQuda(_hp_x, _hp_b, param, op, parity); } diff --git a/lib/solve.cpp b/lib/solve.cpp index d3f03464b6..c0a2ac422b 100644 --- a/lib/solve.cpp +++ b/lib/solve.cpp @@ -1,15 +1,16 @@ #include "invert_quda.h" -namespace quda { +namespace quda +{ // vector of spinors used for forecasting solutions in HMC #define QUDA_MAX_CHRONO 12 // each entry is one p std::vector> chronoResident(QUDA_MAX_CHRONO); - + void massRescale(cvector_ref &b, QudaInvertParam ¶m, bool for_multishift) { - double kappa5 = (0.5/(5.0 + param.m5)); + double kappa5 = (0.5 / (5.0 + param.m5)); double kappa = (param.dslash_type == QUDA_DOMAIN_WALL_DSLASH || param.dslash_type == QUDA_DOMAIN_WALL_4D_DSLASH || param.dslash_type == QUDA_MOBIUS_DWF_DSLASH || param.dslash_type == QUDA_MOBIUS_DWF_EOFA_DSLASH) ? kappa5 : @@ -25,16 +26,15 @@ namespace quda { // staggered dslash uses mass normalization internally if (param.dslash_type == QUDA_ASQTAD_DSLASH || param.dslash_type == QUDA_STAGGERED_DSLASH) { switch (param.solution_type) { - case QUDA_MAT_SOLUTION: - case QUDA_MATPC_SOLUTION: - if (param.mass_normalization == QUDA_KAPPA_NORMALIZATION) blas::ax(2.0*param.mass, b); - break; - case QUDA_MATDAG_MAT_SOLUTION: - case QUDA_MATPCDAG_MATPC_SOLUTION: - if (param.mass_normalization == QUDA_KAPPA_NORMALIZATION) blas::ax(4.0*param.mass*param.mass, b); - break; - default: - errorQuda("Not implemented"); + case QUDA_MAT_SOLUTION: + case QUDA_MATPC_SOLUTION: + if (param.mass_normalization == QUDA_KAPPA_NORMALIZATION) blas::ax(2.0 * param.mass, b); + break; + case QUDA_MATDAG_MAT_SOLUTION: + case QUDA_MATPCDAG_MATPC_SOLUTION: + if (param.mass_normalization == QUDA_KAPPA_NORMALIZATION) blas::ax(4.0 * param.mass * param.mass, b); + break; + default: errorQuda("Not implemented"); } return; } @@ -42,46 +42,45 @@ namespace quda { // multiply the source to 
compensate for normalization of the Dirac operator, if necessary // you are responsible for restoring what's in param.offset switch (param.solution_type) { - case QUDA_MAT_SOLUTION: - if (param.mass_normalization == QUDA_MASS_NORMALIZATION || - param.mass_normalization == QUDA_ASYMMETRIC_MASS_NORMALIZATION) { - blas::ax(2.0*kappa, b); - if (for_multishift) - for (int i = 0; i < param.num_offset; i++) param.offset[i] *= 2.0 * kappa; - } - break; - case QUDA_MATDAG_MAT_SOLUTION: - if (param.mass_normalization == QUDA_MASS_NORMALIZATION || - param.mass_normalization == QUDA_ASYMMETRIC_MASS_NORMALIZATION) { - blas::ax(4.0*kappa*kappa, b); - if (for_multishift) - for (int i = 0; i < param.num_offset; i++) param.offset[i] *= 4.0 * kappa * kappa; - } - break; - case QUDA_MATPC_SOLUTION: - if (param.mass_normalization == QUDA_MASS_NORMALIZATION) { - blas::ax(4.0*kappa*kappa, b); - if (for_multishift) - for (int i = 0; i < param.num_offset; i++) param.offset[i] *= 4.0 * kappa * kappa; - } else if (param.mass_normalization == QUDA_ASYMMETRIC_MASS_NORMALIZATION) { - blas::ax(2.0*kappa, b); - if (for_multishift) - for (int i = 0; i < param.num_offset; i++) param.offset[i] *= 2.0 * kappa; - } - break; - case QUDA_MATPCDAG_MATPC_SOLUTION: - if (param.mass_normalization == QUDA_MASS_NORMALIZATION) { - blas::ax(16.0*std::pow(kappa,4), b); - if (for_multishift) - for (int i = 0; i < param.num_offset; i++) param.offset[i] *= 16.0 * std::pow(kappa, 4); - } else if (param.mass_normalization == QUDA_ASYMMETRIC_MASS_NORMALIZATION) { - blas::ax(4.0*kappa*kappa, b); - if (for_multishift) - for (int i = 0; i < param.num_offset; i++) param.offset[i] *= 4.0 * kappa * kappa; - } - break; - default: - errorQuda("Solution type %d not supported", param.solution_type); + case QUDA_MAT_SOLUTION: + if (param.mass_normalization == QUDA_MASS_NORMALIZATION + || param.mass_normalization == QUDA_ASYMMETRIC_MASS_NORMALIZATION) { + blas::ax(2.0 * kappa, b); + if (for_multishift) + for (int i = 0; i < 
param.num_offset; i++) param.offset[i] *= 2.0 * kappa; + } + break; + case QUDA_MATDAG_MAT_SOLUTION: + if (param.mass_normalization == QUDA_MASS_NORMALIZATION + || param.mass_normalization == QUDA_ASYMMETRIC_MASS_NORMALIZATION) { + blas::ax(4.0 * kappa * kappa, b); + if (for_multishift) + for (int i = 0; i < param.num_offset; i++) param.offset[i] *= 4.0 * kappa * kappa; + } + break; + case QUDA_MATPC_SOLUTION: + if (param.mass_normalization == QUDA_MASS_NORMALIZATION) { + blas::ax(4.0 * kappa * kappa, b); + if (for_multishift) + for (int i = 0; i < param.num_offset; i++) param.offset[i] *= 4.0 * kappa * kappa; + } else if (param.mass_normalization == QUDA_ASYMMETRIC_MASS_NORMALIZATION) { + blas::ax(2.0 * kappa, b); + if (for_multishift) + for (int i = 0; i < param.num_offset; i++) param.offset[i] *= 2.0 * kappa; + } + break; + case QUDA_MATPCDAG_MATPC_SOLUTION: + if (param.mass_normalization == QUDA_MASS_NORMALIZATION) { + blas::ax(16.0 * std::pow(kappa, 4), b); + if (for_multishift) + for (int i = 0; i < param.num_offset; i++) param.offset[i] *= 16.0 * std::pow(kappa, 4); + } else if (param.mass_normalization == QUDA_ASYMMETRIC_MASS_NORMALIZATION) { + blas::ax(4.0 * kappa * kappa, b); + if (for_multishift) + for (int i = 0; i < param.num_offset; i++) param.offset[i] *= 4.0 * kappa * kappa; + } + break; + default: errorQuda("Solution type %d not supported", param.solution_type); } if (getVerbosity() > QUDA_DEBUG_VERBOSE) { @@ -106,7 +105,7 @@ namespace quda { if (param.cuda_prec != QUDA_DOUBLE_PRECISION || param.cuda_prec_sloppy != QUDA_DOUBLE_PRECISION) { warningQuda( - "Using single or half (sloppy) precision in distance preconditioning sometimes makes the solver diverge"); + "Using single or half (sloppy) precision in distance preconditioning sometimes makes the solver diverge"); } if (inverse) @@ -116,18 +115,14 @@ namespace quda { } } - void solve(cvector_ref &x, cvector_ref &b, - Dirac &dirac, Dirac &diracSloppy, Dirac &diracPre, Dirac &diracEig, - 
QudaInvertParam ¶m) + void solve(cvector_ref &x, cvector_ref &b, Dirac &dirac, Dirac &diracSloppy, + Dirac &diracPre, Dirac &diracEig, QudaInvertParam ¶m) { getProfile().TPSTART(QUDA_PROFILE_PREAMBLE); - bool mat_solution = (param.solution_type == QUDA_MAT_SOLUTION) || - (param.solution_type == QUDA_MATPC_SOLUTION); - bool direct_solve = (param.solve_type == QUDA_DIRECT_SOLVE) || - (param.solve_type == QUDA_DIRECT_PC_SOLVE); - bool norm_error_solve = (param.solve_type == QUDA_NORMERR_SOLVE) || - (param.solve_type == QUDA_NORMERR_PC_SOLVE); + bool mat_solution = (param.solution_type == QUDA_MAT_SOLUTION) || (param.solution_type == QUDA_MATPC_SOLUTION); + bool direct_solve = (param.solve_type == QUDA_DIRECT_SOLVE) || (param.solve_type == QUDA_DIRECT_PC_SOLVE); + bool norm_error_solve = (param.solve_type == QUDA_NORMERR_SOLVE) || (param.solve_type == QUDA_NORMERR_PC_SOLVE); auto nb = blas::norm2(b); for (auto &bi : nb) { @@ -258,23 +253,23 @@ namespace quda { } if (getVerbosity() >= QUDA_VERBOSE) { - auto x_norm = blas::norm2(out); + auto x_norm = blas::norm2(out); for (auto i = 0u; i < x.size(); i++) printfQuda("Solution = %g\n", x_norm[i]); } getProfile().TPSTART(QUDA_PROFILE_EPILOGUE); if (param.chrono_make_resident) { const int i = param.chrono_index; - if (i >= QUDA_MAX_CHRONO) - errorQuda("Requested chrono index %d is outside of max %d\n", i, QUDA_MAX_CHRONO); + if (i >= QUDA_MAX_CHRONO) errorQuda("Requested chrono index %d is outside of max %d\n", i, QUDA_MAX_CHRONO); auto &basis = chronoResident[i]; if (param.chrono_max_dim < (int)basis.size()) { - errorQuda("Requested chrono_max_dim %i is smaller than already existing chronology %lu", param.chrono_max_dim, basis.size()); + errorQuda("Requested chrono_max_dim %i is smaller than already existing chronology %lu", param.chrono_max_dim, + basis.size()); } - if(not param.chrono_replace_last){ + if (not param.chrono_replace_last) { // if we have not filled the space yet just augment if ((int)basis.size() < 
param.chrono_max_dim) { ColorSpinorParam cs_param(out[0]); @@ -299,7 +294,7 @@ namespace quda { } if (getVerbosity() >= QUDA_VERBOSE) { - auto x_norm = blas::norm2(x); + auto x_norm = blas::norm2(x); for (auto i = 0u; i < x.size(); i++) printfQuda("Reconstructed Solution = %g\n", x_norm[i]); } @@ -312,4 +307,100 @@ namespace quda { getProfile().TPSTOP(QUDA_PROFILE_EPILOGUE); } -} + void createDiracWithEig(Dirac *&d, Dirac *&dSloppy, Dirac *&dPre, Dirac *&dEig, QudaInvertParam ¶m, + const bool pc_solve); + + extern std::vector solutionResident; + + void solve(const std::vector &hp_x, const std::vector &hp_b, QudaInvertParam ¶m, + const GaugeField &u) + { + if (hp_b.size() != hp_x.size()) + errorQuda("Number of solutions %lu != number of solves %lu", hp_x.size(), hp_b.size()); + int n_src = hp_b.size(); + + // It was probably a bad design decision to encode whether the system is even/odd preconditioned (PC) in + // solve_type and solution_type, rather than in separate members of QudaInvertParam. We're stuck with it + // for now, though, so here we factorize everything for convenience. 
+ + bool pc_solution + = (param.solution_type == QUDA_MATPC_SOLUTION) || (param.solution_type == QUDA_MATPCDAG_MATPC_SOLUTION); + bool pc_solve = (param.solve_type == QUDA_DIRECT_PC_SOLVE) || (param.solve_type == QUDA_NORMOP_PC_SOLVE) + || (param.solve_type == QUDA_NORMERR_PC_SOLVE); + + param.iter = 0; + + Dirac *dirac = nullptr; + Dirac *diracSloppy = nullptr; + Dirac *diracPre = nullptr; + Dirac *diracEig = nullptr; + + // Create the dirac operator and operators for sloppy, precondition, + // and an eigensolver + createDiracWithEig(dirac, diracSloppy, diracPre, diracEig, param, pc_solve); + + // wrap CPU host side pointers + ColorSpinorParam cpuParam(hp_b[0], param, u.X(), pc_solution, param.input_location); + std::vector h_b(n_src); + for (auto i = 0u; i < h_b.size(); i++) { + cpuParam.v = hp_b[i]; + h_b[i] = ColorSpinorField(cpuParam); + } + + std::vector h_x(n_src); + cpuParam.location = param.output_location; + for (auto i = 0u; i < h_x.size(); i++) { + cpuParam.v = hp_x[i]; + h_x[i] = ColorSpinorField(cpuParam); + } + + // download source + ColorSpinorParam cudaParam(cpuParam, param, QUDA_CUDA_FIELD_LOCATION); + cudaParam.create = QUDA_NULL_FIELD_CREATE; + std::vector b; + resize(b, n_src, cudaParam); + blas::copy(b, h_b); + + // now check if we need to invalidate the solutionResident vectors + std::vector x; + resize(x, n_src, cudaParam); + if (param.use_resident_solution == 1) { + for (auto &v : solutionResident) { + if (b[0].Precision() != v.Precision() || b[0].SiteSubset() != v.SiteSubset()) { + solutionResident.clear(); + break; + } + } + + if (!solutionResident.size()) { + cudaParam.create = QUDA_NULL_FIELD_CREATE; + solutionResident = std::vector(1, cudaParam); + } + x[0] = solutionResident[0].create_alias(cudaParam); + } else { + cudaParam.create = QUDA_NULL_FIELD_CREATE; + x[0] = ColorSpinorField(cudaParam); + } + + if (param.use_init_guess == QUDA_USE_INIT_GUESS_YES && !param.chrono_use_resident) { // download initial guess + // initial guess only 
supported for single-pass solvers + if ((param.solution_type == QUDA_MATDAG_MAT_SOLUTION || param.solution_type == QUDA_MATPCDAG_MATPC_SOLUTION) + && (param.solve_type == QUDA_DIRECT_SOLVE || param.solve_type == QUDA_DIRECT_PC_SOLVE)) { + errorQuda("Initial guess not supported for two-pass solver"); + } + + blas::copy(x, h_x); // solution + } else { // zero initial guess + blas::zero(x); + } + + solve(x, b, *dirac, *diracSloppy, *diracPre, *diracEig, param); + + if (!param.make_resident_solution) blas::copy(h_x, x); + + delete dirac; + delete diracSloppy; + delete diracPre; + delete diracEig; + } +} // namespace quda From c51f6e64bcfc1685edc32af3f30ba2b60120e4af Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Thu, 25 Jul 2024 02:14:22 -0700 Subject: [PATCH 030/103] Fix true residual computation: QudaInvertParam::true_res and QudaInvertParam::true_res_hq are now arrays of length QUDA_MAX_MULTI_SRC, allowing us to return the true residual for each MRHS solve independently. Note this is a breaking change in the interface --- include/invert_quda.h | 8 +++--- include/quda.h | 4 +-- include/quda_constants.h | 10 ++------ lib/check_params.h | 5 +++- lib/interface_quda.cpp | 4 +-- lib/milc_interface.cpp | 27 ++++++++++---------- lib/quda_fortran.F90 | 4 +-- tests/deflated_invert_test.cpp | 2 +- tests/host_reference/dslash_reference.cpp | 31 ++++++++++++----------- tests/host_reference/dslash_reference.h | 17 +++++++------ tests/invert_test.cpp | 9 ++++--- tests/multigrid_evolve_test.cpp | 1 + tests/staggered_invert_test.cpp | 4 +-- 13 files changed, 64 insertions(+), 62 deletions(-) diff --git a/include/invert_quda.h b/include/invert_quda.h index 4520712644..b49ea28d1f 100644 --- a/include/invert_quda.h +++ b/include/invert_quda.h @@ -307,8 +307,8 @@ namespace quda { tol_restart(param.tol_restart), tol_hq(param.tol_hq), compute_true_res(param.compute_true_res), - true_res(param.true_res), - true_res_hq(param.true_res_hq), + true_res(param.num_src, 0.0), + 
true_res_hq(param.num_src, 0.0), maxiter(param.maxiter), iter(param.iter), precision(param.cuda_prec), @@ -374,8 +374,8 @@ namespace quda { @param param the QudaInvertParam to be updated */ void updateInvertParam(QudaInvertParam ¶m, int offset=-1) { - param.true_res = static_cast(true_res); - param.true_res_hq = static_cast(true_res_hq); + for (auto i = 0u; i < true_res.size(); i++) param.true_res[i] = true_res[i]; + for (auto i = 0u; i < true_res_hq.size(); i++) param.true_res_hq[i] = true_res_hq[i]; param.iter += iter; if (offset >= 0) { param.true_res_offset[offset] = true_res_offset[offset]; diff --git a/include/quda.h b/include/quda.h index 6f720d19b6..2737e192a6 100644 --- a/include/quda.h +++ b/include/quda.h @@ -145,8 +145,8 @@ extern "C" { double tol_hq; /**< Solver tolerance in the heavy quark residual norm */ int compute_true_res; /** Whether to compute the true residual post solve */ - double true_res; /**< Actual L2 residual norm achieved in solver */ - double true_res_hq; /**< Actual heavy quark residual norm achieved in solver */ + double true_res[QUDA_MAX_MULTI_SRC]; /**< Actual L2 residual norm achieved in the solver */ + double true_res_hq[QUDA_MAX_MULTI_SRC]; /**< Actual heavy quark residual norm achieved in the solver */ int maxiter; /**< Maximum number of iterations in the linear solver */ double reliable_delta; /**< Reliable update tolerance */ double reliable_delta_refinement; /**< Reliable update tolerance used in post multi-shift solver refinement */ diff --git a/include/quda_constants.h b/include/quda_constants.h index 983a49a0aa..983795cba7 100644 --- a/include/quda_constants.h +++ b/include/quda_constants.h @@ -32,15 +32,9 @@ /** * @def QUDA_MAX_BLOCK_SRC - * @brief Maximum number of sources that can be supported by the block solver + * @brief Maximum number of sources that can be supported by the multi-src solver */ -#define QUDA_MAX_BLOCK_SRC 64 - -/** - * @def QUDA_MAX_ARRAY - * @brief Maximum array length used in QudaInvertParam 
arrays - */ -#define QUDA_MAX_ARRAY_SIZE (QUDA_MAX_MULTI_SHIFT > QUDA_MAX_BLOCK_SRC ? QUDA_MAX_MULTI_SHIFT : QUDA_MAX_BLOCK_SRC) +#define QUDA_MAX_MULTI_SRC 128 /** * @def QUDA_MAX_DWF_LS diff --git a/lib/check_params.h b/lib/check_params.h index 082cc56af4..8a9f93c81b 100644 --- a/lib/check_params.h +++ b/lib/check_params.h @@ -437,8 +437,11 @@ void printQudaInvertParam(QudaInvertParam *param) { #ifndef CHECK_PARAM P(pipeline, 0); /** Whether to use a pipelined solver */ P(num_offset, 0); /**< Number of offsets in the multi-shift solver */ - P(num_src, 1); /**< Number of offsets in the multi-shift solver */ + P(num_src, 1); /**< Number of sources to solve for simultaneously */ P(overlap, 0); /**< width of domain overlaps */ +#else + if (param->num_src > QUDA_MAX_MULTI_SRC) + errorQuda("num_src %d exceeds limit of %d", param->num_src, QUDA_MAX_MULTI_SRC); #endif #ifdef INIT_PARAM diff --git a/lib/interface_quda.cpp b/lib/interface_quda.cpp index 75771e1ebb..9de4b4aa50 100644 --- a/lib/interface_quda.cpp +++ b/lib/interface_quda.cpp @@ -3095,6 +3095,8 @@ void callMultiSrcQuda(void **_hp_x, void **_hp_b, QudaInvertParam *param, // col errorQuda("split_key = [%d,%d,%d,%d] is not valid", split_key[0], split_key[1], split_key[2], split_key[3]); } + checkInvertParam(param, _hp_x[0], _hp_b[0]); + if (num_sub_partition == 1) { // In this case we don't split the grid. std::vector x(param->num_src), b(param->num_src); @@ -3118,8 +3120,6 @@ void callMultiSrcQuda(void **_hp_x, void **_hp_b, QudaInvertParam *param, // col // Doesn't work for MG yet. 
if (param->inv_type_precondition == QUDA_MG_INVERTER) errorQuda("Split Grid does NOT work with MG yet"); - checkInvertParam(param, _hp_x[0], _hp_b[0]); - bool is_staggered = false; bool is_asqtad = false; diff --git a/lib/milc_interface.cpp b/lib/milc_interface.cpp index 131c57fde0..99d0c1b89c 100644 --- a/lib/milc_interface.cpp +++ b/lib/milc_interface.cpp @@ -1209,8 +1209,8 @@ void qudaInvert(int external_precision, int quda_precision, double mass, QudaInv // return the number of iterations taken by the inverter *num_iters = invertParam.iter; - *final_residual = invertParam.true_res; - *final_fermilab_residual = invertParam.true_res_hq; + *final_residual = invertParam.true_res[0]; + *final_fermilab_residual = invertParam.true_res_hq[0]; if (!create_quda_gauge) invalidateGaugeQuda(); @@ -1426,8 +1426,9 @@ void qudaInvertMsrc(int external_precision, int quda_precision, double mass, Qud // return the number of iterations taken by the inverter *num_iters = invertParam.iter; - *final_residual = invertParam.true_res; - *final_fermilab_residual = invertParam.true_res_hq; + // FIXME MILC seems to only care about a single residual? 
+ *final_residual = invertParam.true_res[0]; + *final_fermilab_residual = invertParam.true_res_hq[0]; if (!create_quda_gauge) invalidateGaugeQuda(); @@ -1521,8 +1522,8 @@ void qudaEigCGInvert(int external_precision, int quda_precision, double mass, Qu // return the number of iterations taken by the inverter *num_iters = invertParam.iter; - *final_residual = invertParam.true_res; - *final_fermilab_residual = invertParam.true_res_hq; + *final_residual = invertParam.true_res[0]; + *final_fermilab_residual = invertParam.true_res_hq[0]; if (!create_quda_gauge && last_rhs_flag) invalidateGaugeQuda(); @@ -2707,8 +2708,8 @@ void qudaInvertMG(int external_precision, int quda_precision, double mass, QudaI // return the number of iterations taken by the inverter *num_iters = invertParam.iter; - *final_residual = invertParam.true_res; - *final_fermilab_residual = invertParam.true_res_hq; + *final_residual = invertParam.true_res[0]; + *final_fermilab_residual = invertParam.true_res_hq[0]; if (!create_quda_gauge) invalidateGaugeQuda(); @@ -3004,8 +3005,8 @@ void qudaCloverInvert(int external_precision, invertQuda(solution, source, &invertParam); *num_iters = invertParam.iter; - *final_residual = invertParam.true_res; - *final_fermilab_residual = invertParam.true_res_hq; + *final_residual = invertParam.true_res[0]; + *final_fermilab_residual = invertParam.true_res_hq[0]; if (clover || cloverInverse) qudaFreeCloverField(); if (link) qudaFreeGaugeField(); @@ -3083,8 +3084,8 @@ void qudaEigCGCloverInvert(int external_precision, int quda_precision, double ka if (last_rhs_flag) destroyDeflationQuda(df_preconditioner); *num_iters = invertParam.iter; - *final_residual = invertParam.true_res; - *final_fermilab_residual = invertParam.true_res_hq; + *final_residual = invertParam.true_res[0]; + *final_fermilab_residual = invertParam.true_res_hq[0]; if ( (clover || cloverInverse) && last_rhs_flag) qudaFreeCloverField(); if (link && last_rhs_flag) qudaFreeGaugeField(); @@ -3151,7 +3152,7 @@ 
void qudaCloverMultishiftInvert(int external_precision, int quda_precision, int } invertQuda(solutionArray[0], source, &invertParam); - *final_residual = invertParam.true_res; + *final_residual = invertParam.true_res[0]; } else { invertMultiShiftQuda(solutionArray, source, &invertParam); for (int i=0; i verifyInversion(void *spinorOut, void *spinorIn, void *spinorCheck, QudaGaugeParam &gauge_param, - QudaInvertParam &inv_param, void **gauge, void *clover, void *clover_inv) + QudaInvertParam &inv_param, void **gauge, void *clover, void *clover_inv, + int src_idx) { void **spinorOutMulti = nullptr; return verifyInversion(spinorOut, spinorOutMulti, spinorIn, spinorCheck, gauge_param, inv_param, gauge, clover, - clover_inv); + clover_inv, src_idx); } std::array verifyInversion(void *spinorOut, void **spinorOutMulti, void *spinorIn, void *spinorCheck, QudaGaugeParam &gauge_param, QudaInvertParam &inv_param, void **gauge, - void *clover, void *clover_inv) + void *clover, void *clover_inv, int src_idx) { std::array res = {std::numeric_limits::max(), std::numeric_limits::max()}; if (dslash_type == QUDA_DOMAIN_WALL_DSLASH || dslash_type == QUDA_DOMAIN_WALL_4D_DSLASH || dslash_type == QUDA_MOBIUS_DWF_DSLASH || dslash_type == QUDA_MOBIUS_DWF_EOFA_DSLASH) { res = verifyDomainWallTypeInversion(spinorOut, spinorOutMulti, spinorIn, spinorCheck, gauge_param, inv_param, gauge, - clover, clover_inv); + clover, clover_inv, src_idx); } else if (dslash_type == QUDA_WILSON_DSLASH || dslash_type == QUDA_CLOVER_WILSON_DSLASH || dslash_type == QUDA_TWISTED_MASS_DSLASH || dslash_type == QUDA_TWISTED_CLOVER_DSLASH) { res = verifyWilsonTypeInversion(spinorOut, spinorOutMulti, spinorIn, spinorCheck, gauge_param, inv_param, gauge, - clover, clover_inv); + clover, clover_inv, src_idx); } else { errorQuda("Unsupported dslash_type=%s", get_dslash_str(dslash_type)); } @@ -40,7 +41,7 @@ std::array verifyInversion(void *spinorOut, void **spinorOutMulti, vo std::array 
verifyDomainWallTypeInversion(void *spinorOut, void **, void *spinorIn, void *spinorCheck, QudaGaugeParam &gauge_param, QudaInvertParam &inv_param, - void **gauge, void *, void *) + void **gauge, void *, void *, int src_idx) { if (multishift > 1) errorQuda("Multishift not supported"); @@ -163,15 +164,15 @@ std::array verifyDomainWallTypeInversion(void *spinorOut, void **, vo double l2r = sqrt(nrm2 / src2); printfQuda("Residuals: (L2 relative) tol %9.6e, QUDA = %9.6e, host = %9.6e; (heavy-quark) tol %9.6e, QUDA = %9.6e\n", - inv_param.tol, inv_param.true_res, l2r, inv_param.tol_hq, inv_param.true_res_hq); + inv_param.tol, inv_param.true_res[src_idx], l2r, inv_param.tol_hq, inv_param.true_res_hq[src_idx]); return {l2r, inv_param.tol_hq}; ; } -std::array verifyWilsonTypeInversion(void *spinorOut, void **spinorOutMulti, void *spinorIn, - void *spinorCheck, QudaGaugeParam &gauge_param, - QudaInvertParam &inv_param, void **gauge, void *clover, void *clover_inv) +std::array verifyWilsonTypeInversion(void *spinorOut, void **spinorOutMulti, void *spinorIn, void *spinorCheck, + QudaGaugeParam &gauge_param, QudaInvertParam &inv_param, void **gauge, + void *clover, void *clover_inv, int src_idx) { int vol = (inv_param.solution_type == QUDA_MAT_SOLUTION || inv_param.solution_type == QUDA_MATDAG_MAT_SOLUTION ? 
V : Vh); @@ -409,7 +410,7 @@ std::array verifyWilsonTypeInversion(void *spinorOut, void **spinorOu printfQuda( "Residuals: (L2 relative) tol %9.6e, QUDA = %9.6e, host = %9.6e; (heavy-quark) tol %9.6e, QUDA = %9.6e\n", - inv_param.tol, inv_param.true_res, l2r, inv_param.tol_hq, inv_param.true_res_hq); + inv_param.tol, inv_param.true_res[src_idx], l2r, inv_param.tol_hq, inv_param.true_res_hq[src_idx]); } return {l2r_max, inv_param.tol_hq}; @@ -745,17 +746,17 @@ double verifyWilsonTypeSingularVector(void *spinor_left, void *spinor_right, dou std::array verifyStaggeredInversion(quda::ColorSpinorField &in, quda::ColorSpinorField &out, quda::GaugeField &fat_link, quda::GaugeField &long_link, - QudaInvertParam &inv_param) + QudaInvertParam &inv_param, int src_idx) { std::vector out_vector(1); out_vector[0] = out; - return verifyStaggeredInversion(in, out_vector, fat_link, long_link, inv_param); + return verifyStaggeredInversion(in, out_vector, fat_link, long_link, inv_param, src_idx); } std::array verifyStaggeredInversion(quda::ColorSpinorField &in, std::vector &out_vector, quda::GaugeField &fat_link, quda::GaugeField &long_link, - QudaInvertParam &inv_param) + QudaInvertParam &inv_param, int src_idx) { int dagger = inv_param.dagger == QUDA_DAG_YES ? 
1 : 0; double l2r_max = 0.0; @@ -834,7 +835,7 @@ std::array verifyStaggeredInversion(quda::ColorSpinorField &in, printfQuda("Residuals: (L2 relative) tol %9.6e, QUDA = %9.6e, host = %9.6e; (heavy-quark) tol %9.6e, QUDA = %9.6e, " "host = %9.6e\n", - inv_param.tol, inv_param.true_res, l2r, inv_param.tol_hq, inv_param.true_res_hq, hqr); + inv_param.tol, inv_param.true_res[src_idx], l2r, inv_param.tol_hq, inv_param.true_res_hq[src_idx], hqr); l2r_max = l2r; hqr_max = hqr; diff --git a/tests/host_reference/dslash_reference.h b/tests/host_reference/dslash_reference.h index cf7f176c4c..83b8e5934e 100644 --- a/tests/host_reference/dslash_reference.h +++ b/tests/host_reference/dslash_reference.h @@ -87,16 +87,17 @@ static inline void su3Tmul(sFloat *res, const gFloat *mat, const sFloat *vec) } std::array verifyInversion(void *spinorOut, void *spinorIn, void *spinorCheck, QudaGaugeParam &gauge_param, - QudaInvertParam &inv_param, void **gauge, void *clover, void *clover_inv); + QudaInvertParam &inv_param, void **gauge, void *clover, void *clover_inv, + int src_idx); std::array verifyInversion(void *spinorOut, void **spinorOutMulti, void *spinorIn, void *spinorCheck, QudaGaugeParam &gauge_param, QudaInvertParam &inv_param, void **gauge, - void *clover, void *clover_inv); + void *clover, void *clover_inv, int src_idx = 0); std::array verifyDomainWallTypeInversion(void *spinorOut, void **spinorOutMulti, void *spinorIn, void *spinorCheck, QudaGaugeParam &gauge_param, QudaInvertParam &inv_param, void **gauge, void *clover, - void *clover_inv); + void *clover_inv, int src_idx); double verifyWilsonTypeEigenvector(void *spinor, double _Complex lambda, int i, QudaGaugeParam &gauge_param, QudaEigParam &eig_param, void **gauge, void *clover, void *clover_inv); @@ -105,9 +106,9 @@ double verifyWilsonTypeSingularVector(void *spinor_left, void *spinor_right, dou QudaGaugeParam &gauge_param, QudaEigParam &eig_param, void **gauge, void *clover, void *clover_inv); -std::array 
verifyWilsonTypeInversion(void *spinorOut, void **spinorOutMulti, void *spinorIn, - void *spinorCheck, QudaGaugeParam &gauge_param, - QudaInvertParam &inv_param, void **gauge, void *clover, void *clover_inv); +std::array verifyWilsonTypeInversion(void *spinorOut, void **spinorOutMulti, void *spinorIn, void *spinorCheck, + QudaGaugeParam &gauge_param, QudaInvertParam &inv_param, void **gauge, + void *clover, void *clover_inv, int src_idx); /** * @brief Verify a staggered inversion on the host. This version is a thin wrapper around a version that takes @@ -122,7 +123,7 @@ std::array verifyWilsonTypeInversion(void *spinorOut, void **spinorOu */ std::array verifyStaggeredInversion(quda::ColorSpinorField &in, quda::ColorSpinorField &out, quda::GaugeField &fat_link, quda::GaugeField &long_link, - QudaInvertParam &inv_param); + QudaInvertParam &inv_param, int src_idx); /** * @brief Verify a single- or multi-shift staggered inversion on the host @@ -137,7 +138,7 @@ std::array verifyStaggeredInversion(quda::ColorSpinorField &in, quda: std::array verifyStaggeredInversion(quda::ColorSpinorField &in, std::vector &out_vector, quda::GaugeField &fat_link, quda::GaugeField &long_link, - QudaInvertParam &inv_param); + QudaInvertParam &inv_param, int src_idx = 0); /** * @brief Verify a staggered-type eigenvector diff --git a/tests/invert_test.cpp b/tests/invert_test.cpp index bf56dc2162..a16c765c96 100644 --- a/tests/invert_test.cpp +++ b/tests/invert_test.cpp @@ -47,6 +47,7 @@ void display_test_info() printfQuda(" - number of levels %d\n", mg_levels); for (int i = 0; i < mg_levels - 1; i++) { printfQuda(" - level %d number of null-space vectors %d\n", i + 1, nvec[i]); + printfQuda(" - level %d null-space vector batch size %d\n", i + 1, nvec_batch[i]); printfQuda(" - level %d number of pre-smoother applications %d\n", i + 1, nu_pre[i]); printfQuda(" - level %d number of post-smoother applications %d\n", i + 1, nu_post[i]); } @@ -304,7 +305,7 @@ std::vector> solve(test_t param) 
verifySpinorDistanceReweight(in[0], distance_pc_alpha0, distance_pc_t0); } - if (!use_split_grid) { + if (!use_split_grid && Nsrc == 1) { for (int i = 0; i < Nsrc; i++) { // If deflating, preserve the deflation space between solves @@ -333,8 +334,8 @@ std::vector> solve(test_t param) _hp_x[i] = out[i].data(); _hp_b[i] = in[i].data(); } - - // Run split grid + + // Run split grid invertMultiSrcQuda(_hp_x.data(), _hp_b.data(), &inv_param); quda::comm_allreduce_int(inv_param.iter); @@ -360,7 +361,7 @@ std::vector> solve(test_t param) if (verify_results) { for (int i = 0; i < Nsrc; i++) { res[i] = verifyInversion(out[i].data(), _hp_multi_x[i].data(), in[i].data(), check.data(), gauge_param, inv_param, - gauge.data(), clover.data(), clover_inv.data()); + gauge.data(), clover.data(), clover_inv.data(), i); } } return res; diff --git a/tests/multigrid_evolve_test.cpp b/tests/multigrid_evolve_test.cpp index 9545942a51..bc5f38cfed 100644 --- a/tests/multigrid_evolve_test.cpp +++ b/tests/multigrid_evolve_test.cpp @@ -63,6 +63,7 @@ void display_test_info() printfQuda(" - level %d number of null-space vectors %d\n", i + 1, nvec[i]); printfQuda(" - level %d number of pre-smoother applications %d\n", i + 1, nu_pre[i]); printfQuda(" - level %d number of post-smoother applications %d\n", i + 1, nu_post[i]); + printfQuda(" - level %d null-space vector batch size %d\n", i + 1, nvec_batch[i]); } printfQuda("Outer solver paramers\n"); diff --git a/tests/staggered_invert_test.cpp b/tests/staggered_invert_test.cpp index 9a00b3d875..ac4f0223ba 100644 --- a/tests/staggered_invert_test.cpp +++ b/tests/staggered_invert_test.cpp @@ -361,7 +361,7 @@ std::vector> solve(test_t param) // QUDA invert test //---------------------------------------------------------------------------- - if (!use_split_grid) { + if (!use_split_grid && Nsrc == 1) { for (int n = 0; n < Nsrc; n++) { // If deflating, preserve the deflation space between solves @@ -418,7 +418,7 @@ std::vector> solve(test_t param) = 
{out_multishift.begin() + n * multishift, out_multishift.begin() + (n + 1) * multishift}; res[n] = verifyStaggeredInversion(in[n], out_subset, cpuFatQDP, cpuLongQDP, inv_param); } else { - res[n] = verifyStaggeredInversion(in[n], out[n], cpuFatQDP, cpuLongQDP, inv_param); + res[n] = verifyStaggeredInversion(in[n], out[n], cpuFatQDP, cpuLongQDP, inv_param, n); } } } From c8773521590de67209f575b3cac64c5eca6f4a1c Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Thu, 25 Jul 2024 03:25:24 -0700 Subject: [PATCH 031/103] invert_test and staggered_invert_test now respect --nsrc-tile flag for setting number of concurrent solves --- tests/invert_test.cpp | 55 +++++++++++++++++++++----------- tests/staggered_invert_test.cpp | 56 ++++++++++++++++++++++----------- 2 files changed, 73 insertions(+), 38 deletions(-) diff --git a/tests/invert_test.cpp b/tests/invert_test.cpp index a16c765c96..890d357078 100644 --- a/tests/invert_test.cpp +++ b/tests/invert_test.cpp @@ -21,6 +21,7 @@ QudaInvertParam mg_inv_param; QudaEigParam mg_eig_param[QUDA_MAX_MG_LEVEL]; QudaEigParam eig_param; bool use_split_grid = false; +bool use_multi_src = false; std::vector gauge_; std::array gauge; @@ -226,6 +227,7 @@ std::vector> solve(test_t param) for (int i = 0; i < 4; i++) inv_param.split_grid[i] = grid_partition[i]; int num_sub_partition = grid_partition[0] * grid_partition[1] * grid_partition[2] * grid_partition[3]; use_split_grid = num_sub_partition > 1; + use_multi_src = use_split_grid || (Nsrc_tile > 1); // Now QUDA is initialised and the fields are loaded, we may setup the preconditioner void *mg_preconditioner = nullptr; @@ -239,6 +241,8 @@ std::vector> solve(test_t param) // Vector construct START //----------------------------------------------------------------------------------- + if (Nsrc > QUDA_MAX_MULTI_SRC) + errorQuda("Nsrc = %d which is great than QUDA_MAX_MULTI_SRC = %d\n", Nsrc, QUDA_MAX_MULTI_SRC); std::vector in(Nsrc); std::vector out(Nsrc); std::vector 
out_multishift(multishift * Nsrc); @@ -305,7 +309,7 @@ std::vector> solve(test_t param) verifySpinorDistanceReweight(in[0], distance_pc_alpha0, distance_pc_t0); } - if (!use_split_grid && Nsrc == 1) { + if (!use_multi_src) { for (int i = 0; i < Nsrc; i++) { // If deflating, preserve the deflation space between solves @@ -317,34 +321,47 @@ std::vector> solve(test_t param) invertQuda(out[i].data(), in[i].data(), &inv_param); } + // move residuals to i^th location for verification after solves have finished + inv_param.true_res[i] = inv_param.true_res[0]; + inv_param.true_res_hq[i] = inv_param.true_res_hq[0]; + time[i] = inv_param.secs; gflops[i] = inv_param.gflops / inv_param.secs; iter[i] = inv_param.iter; printfQuda("Done: %i iter / %g secs = %g Gflops\n", inv_param.iter, inv_param.secs, inv_param.gflops / inv_param.secs); } + } else { - inv_param.num_src = Nsrc; - inv_param.num_src_per_sub_partition = Nsrc / num_sub_partition; + inv_param.num_src = Nsrc_tile; + inv_param.num_src_per_sub_partition = Nsrc_tile / num_sub_partition; // Host arrays for solutions, sources, and check - std::vector _hp_x(Nsrc); - std::vector _hp_b(Nsrc); - for (int i = 0; i < Nsrc; i++) { - _hp_x[i] = out[i].data(); - _hp_b[i] = in[i].data(); - } + std::vector _hp_x(Nsrc_tile); + std::vector _hp_b(Nsrc_tile); - // Run split grid - invertMultiSrcQuda(_hp_x.data(), _hp_b.data(), &inv_param); + for (int j = 0; j < Nsrc; j += Nsrc_tile) { + for (int i = 0; i < Nsrc_tile; i++) { + _hp_x[i] = out[j + i].data(); + _hp_b[i] = in[j + i].data(); + } - quda::comm_allreduce_int(inv_param.iter); - inv_param.iter /= quda::comm_size() / num_sub_partition; - quda::comm_allreduce_sum(inv_param.gflops); - inv_param.gflops /= quda::comm_size() / num_sub_partition; - quda::comm_allreduce_max(inv_param.secs); - printfQuda("Done: %d sub-partitions - %i iter / %g secs = %g Gflops\n", num_sub_partition, inv_param.iter, - inv_param.secs, inv_param.gflops / inv_param.secs); + invertMultiSrcQuda(_hp_x.data(), 
_hp_b.data(), &inv_param); + + // move residuals to (i+j)^th location for verification after solves have finished + for (int i = 0; i < Nsrc_tile; i++) { + inv_param.true_res[j + i] = inv_param.true_res[i]; + inv_param.true_res_hq[j + i] = inv_param.true_res_hq[i]; + } + + quda::comm_allreduce_int(inv_param.iter); + inv_param.iter /= quda::comm_size() / num_sub_partition; + quda::comm_allreduce_sum(inv_param.gflops); + inv_param.gflops /= quda::comm_size() / num_sub_partition; + quda::comm_allreduce_max(inv_param.secs); + printfQuda("Done: %d sub-partitions - %i iter / %g secs = %g Gflops\n", num_sub_partition, inv_param.iter, + inv_param.secs, inv_param.gflops / inv_param.secs); + } } // QUDA invert test COMPLETE @@ -354,7 +371,7 @@ std::vector> solve(test_t param) if (inv_multigrid) destroyMultigridQuda(mg_preconditioner); // Compute performance statistics - if (Nsrc > 1 && !use_split_grid) performanceStats(time, gflops, iter); + if (!use_multi_src) performanceStats(time, gflops, iter); std::vector> res(Nsrc); // Perform host side verification of inversion if requested diff --git a/tests/staggered_invert_test.cpp b/tests/staggered_invert_test.cpp index ac4f0223ba..a557c15797 100644 --- a/tests/staggered_invert_test.cpp +++ b/tests/staggered_invert_test.cpp @@ -25,6 +25,7 @@ QudaInvertParam mg_inv_param; QudaEigParam mg_eig_param[QUDA_MAX_MG_LEVEL]; QudaEigParam eig_param; bool use_split_grid = false; +bool use_multi_src = false; // print instructions on how to run the old tests bool print_legacy_info = false; @@ -271,6 +272,7 @@ std::vector> solve(test_t param) for (int i = 0; i < 4; i++) inv_param.split_grid[i] = grid_partition[i]; int num_sub_partition = grid_partition[0] * grid_partition[1] * grid_partition[2] * grid_partition[3]; use_split_grid = num_sub_partition > 1; + use_multi_src = use_split_grid || (Nsrc_tile > 1); // Setup the multigrid preconditioner void *mg_preconditioner = nullptr; @@ -284,6 +286,8 @@ std::vector> solve(test_t param) // Staggered 
vector construct START //----------------------------------------------------------------------------------- + if (Nsrc > QUDA_MAX_MULTI_SRC) + errorQuda("Nsrc = %d which is great than QUDA_MAX_MULTI_SRC = %d\n", Nsrc, QUDA_MAX_MULTI_SRC); std::vector in(Nsrc); std::vector out(Nsrc); std::vector out_multishift(Nsrc * multishift); @@ -361,7 +365,7 @@ std::vector> solve(test_t param) // QUDA invert test //---------------------------------------------------------------------------- - if (!use_split_grid && Nsrc == 1) { + if (!use_multi_src) { for (int n = 0; n < Nsrc; n++) { // If deflating, preserve the deflation space between solves @@ -373,6 +377,10 @@ std::vector> solve(test_t param) invertQuda(out[n].data(), in[n].data(), &inv_param); } + // move residuals to n^th location for verification after solves have finished + inv_param.true_res[n] = inv_param.true_res[0]; + inv_param.true_res_hq[n] = inv_param.true_res_hq[0]; + time[n] = inv_param.secs; gflops[n] = inv_param.gflops / inv_param.secs; iter[n] = inv_param.iter; @@ -380,32 +388,42 @@ std::vector> solve(test_t param) inv_param.gflops / inv_param.secs); } } else { - inv_param.num_src = Nsrc; - inv_param.num_src_per_sub_partition = Nsrc / num_sub_partition; + + inv_param.num_src = Nsrc_tile; + inv_param.num_src_per_sub_partition = Nsrc_tile / num_sub_partition; // Host arrays for solutions, sources, and check - std::vector _hp_x(Nsrc); - std::vector _hp_b(Nsrc); - for (int n = 0; n < Nsrc; n++) { - _hp_x[n] = out[n].data(); - _hp_b[n] = in[n].data(); + std::vector _hp_x(Nsrc_tile); + std::vector _hp_b(Nsrc_tile); + + for (int j = 0; j < Nsrc; j += Nsrc_tile) { + for (int i = 0; i < Nsrc_tile; i++) { + _hp_x[i] = out[j + i].data(); + _hp_b[i] = in[j + i].data(); + } + + invertMultiSrcQuda(_hp_x.data(), _hp_b.data(), &inv_param); + + // move residuals to (i+j)^th location for verification after solves have finished + for (int i = 0; i < Nsrc_tile; i++) { + inv_param.true_res[j + i] = inv_param.true_res[i]; + 
inv_param.true_res_hq[j + i] = inv_param.true_res_hq[i]; + } + + quda::comm_allreduce_int(inv_param.iter); + inv_param.iter /= comm_size() / num_sub_partition; + quda::comm_allreduce_sum(inv_param.gflops); + inv_param.gflops /= comm_size() / num_sub_partition; + quda::comm_allreduce_max(inv_param.secs); + printfQuda("Done: %d sub-partitions - %i iter / %g secs = %g Gflops\n\n", num_sub_partition, inv_param.iter, + inv_param.secs, inv_param.gflops / inv_param.secs); } - // Run split grid - invertMultiSrcQuda(_hp_x.data(), _hp_b.data(), &inv_param); - - quda::comm_allreduce_int(inv_param.iter); - inv_param.iter /= comm_size() / num_sub_partition; - quda::comm_allreduce_sum(inv_param.gflops); - inv_param.gflops /= comm_size() / num_sub_partition; - quda::comm_allreduce_max(inv_param.secs); - printfQuda("Done: %d sub-partitions - %i iter / %g secs = %g Gflops\n\n", num_sub_partition, inv_param.iter, - inv_param.secs, inv_param.gflops / inv_param.secs); } // Free the multigrid solver if (inv_multigrid) destroyMultigridQuda(mg_preconditioner); // Compute timings - if (Nsrc > 1 && !use_split_grid) performanceStats(time, gflops, iter); + if (!use_multi_src) performanceStats(time, gflops, iter); std::vector> res(Nsrc); // Perform host side verification of inversion if requested From 19d33480d706dc4d766b20c4b0e1c09af02507fb Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Fri, 26 Jul 2024 01:11:49 -0700 Subject: [PATCH 032/103] Add some size checks to P and R --- lib/transfer.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lib/transfer.cpp b/lib/transfer.cpp index ae106b92ac..b3925ccc4c 100644 --- a/lib/transfer.cpp +++ b/lib/transfer.cpp @@ -339,6 +339,7 @@ namespace quda { // apply the prolongator void Transfer::P(cvector_ref &out, cvector_ref &in) const { getProfile().TPSTART(QUDA_PROFILE_COMPUTE); + if (out.size() != in.size()) errorQuda("Mismatched set sizes %lu != %lu", out.size(), in.size()); initializeLazy(use_gpu ? 
QUDA_CUDA_FIELD_LOCATION : QUDA_CPU_FIELD_LOCATION); const int *fine_to_coarse = use_gpu ? fine_to_coarse_d : fine_to_coarse_h; @@ -413,6 +414,7 @@ namespace quda { void Transfer::R(cvector_ref &out, cvector_ref &in) const { getProfile().TPSTART(QUDA_PROFILE_COMPUTE); + if (out.size() != in.size()) errorQuda("Mismatched set sizes %lu != %lu", out.size(), in.size()); initializeLazy(use_gpu ? QUDA_CUDA_FIELD_LOCATION : QUDA_CPU_FIELD_LOCATION); const int *fine_to_coarse = use_gpu ? fine_to_coarse_d : fine_to_coarse_h; From 3583ef801685eb157373a152dc0929eca7162a31 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Fri, 26 Jul 2024 02:56:56 -0700 Subject: [PATCH 033/103] Use batched blas in DiracCoarse --- lib/dirac_coarse.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/dirac_coarse.cpp b/lib/dirac_coarse.cpp index 55e827b0ca..7bf1225f7c 100644 --- a/lib/dirac_coarse.cpp +++ b/lib/dirac_coarse.cpp @@ -528,7 +528,7 @@ namespace quda { { // FIXME emulated for now Dslash(out, in, parity); - for (auto i = 0u; i < x.size(); i++) blas::xpay(x[i], k, out[i]); + blas::xpay(x, k, out); } void DiracCoarsePC::M(cvector_ref &out, cvector_ref &in) const @@ -545,14 +545,14 @@ namespace quda { // DiracCoarse::DslashXpay applies (A - D) // FIXME this ignores the -1 DiracCoarse::Dslash(out, tmp, QUDA_EVEN_PARITY); Clover(tmp, in, QUDA_EVEN_PARITY); - for (auto i = 0u; i < in.size(); i++) blas::xpay(tmp[i], -1.0, out[i]); + blas::xpay(tmp, -1.0, out); } else if (matpcType == QUDA_MATPC_ODD_ODD_ASYMMETRIC) { // DiracCoarsePC::Dslash applies A^{-1}Dslash Dslash(tmp, in, QUDA_EVEN_PARITY); // DiracCoarse::DslashXpay applies (A - D) // FIXME this ignores the -1 DiracCoarse::Dslash(out, tmp, QUDA_ODD_PARITY); Clover(tmp, in, QUDA_ODD_PARITY); - for (auto i = 0u; i < in.size(); i++) blas::xpay(tmp[i], -1.0, out[i]); + blas::xpay(tmp, -1.0, out); } else if (matpcType == QUDA_MATPC_EVEN_EVEN) { Dslash(tmp, in, QUDA_ODD_PARITY); DslashXpay(out, tmp, 
QUDA_EVEN_PARITY, in, -1.0); From f64a9ac77faee5e18054bc3d1f62897c0a14c04e Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Fri, 26 Jul 2024 04:54:17 -0700 Subject: [PATCH 034/103] Set verbosity in solve() --- lib/solve.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/lib/solve.cpp b/lib/solve.cpp index c0a2ac422b..730152f60d 100644 --- a/lib/solve.cpp +++ b/lib/solve.cpp @@ -315,6 +315,8 @@ namespace quda void solve(const std::vector &hp_x, const std::vector &hp_b, QudaInvertParam ¶m, const GaugeField &u) { + pushVerbosity(param.verbosity); + if (hp_b.size() != hp_x.size()) errorQuda("Number of solutions %lu != number of solves %lu", hp_x.size(), hp_b.size()); int n_src = hp_b.size(); @@ -402,5 +404,7 @@ namespace quda delete diracSloppy; delete diracPre; delete diracEig; + + popVerbosity(); } } // namespace quda From fa26c4ea01c2e3ceff7413ad5b92b2908be9f27b Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Fri, 26 Jul 2024 05:03:41 -0700 Subject: [PATCH 035/103] GCR, CA-GCR and PreconditionedSolver are now MRHS aware --- include/invert_quda.h | 102 +++++++++++++++------------- lib/inv_ca_gcr.cpp | 111 ++++++++++++++++++------------ lib/inv_gcr_quda.cpp | 154 ++++++++++++++++++++++++------------------ lib/solver.hpp | 10 +-- 4 files changed, 219 insertions(+), 158 deletions(-) diff --git a/include/invert_quda.h b/include/invert_quda.h index b49ea28d1f..62d9019523 100644 --- a/include/invert_quda.h +++ b/include/invert_quda.h @@ -670,10 +670,8 @@ namespace quda { @brief Compute power iterations on a Dirac matrix @param[in] diracm Dirac matrix used for power iterations @param[in] start Starting rhs for power iterations; value preserved unless it aliases tempvec1 or tempvec2 - @param[in,out] tempvec1 Temporary vector used for power iterations (FIXME: can become a reference when std::swap - can be used on ColorSpinorField) - @param[in,out] tempvec2 Temporary vector used for power iterations (FIXME: can become a reference when std::swap - can be used on 
ColorSpinorField) + @param[in,out] tempvec1 Temporary vector used for power iterations + @param[in,out] tempvec2 Temporary vector used for power iterations @param[in] niter Total number of power iteration iterations @param[in] normalize_freq Frequency with which intermediate vector gets normalized @param[in] args Parameter pack of ColorSpinorFields used as temporary passed to Dirac @@ -687,8 +685,8 @@ namespace quda { /** @brief Generate a Krylov space in a given basis @param[in] diracm Dirac matrix used to generate the Krylov space - @param[out] Ap dirac matrix times the Krylov basis vectors - @param[in,out] p Krylov basis vectors; assumes p[0] is in place + @param[out] Ap dirac matrix times the Krylov basis vector sets + @param[in,out] p Krylov basis vector sets; assumes p[0] is in place @param[in] n_krylov Size of krylov space @param[in] basis Basis type @param[in] m_map Slope mapping for Chebyshev basis; ignored for power basis @@ -696,9 +694,32 @@ namespace quda { @param[in] args Parameter pack of ColorSpinorFields used as temporary passed to Dirac */ template - static void computeCAKrylovSpace(const DiracMatrix &diracm, std::vector &Ap, - std::vector &p, int n_krylov, QudaCABasis basis, double m_map, - double b_map, Args &&...args); + static void computeCAKrylovSpace(const DiracMatrix &diracm, std::vector> &Ap, + std::vector> &p, int n_krylov, QudaCABasis basis, + double m_map, double b_map, Args &&...args); + + // FIXME delete this variant once CA-CG is MRHS aware + template + void computeCAKrylovSpace(const DiracMatrix &diracm, std::vector &Ap, + std::vector &p, int n_krylov, QudaCABasis basis, double m_map, + double b_map, Args &&...args) + { + std::vector> p2(p.size()); + for (auto i = 0u; i < p.size(); i++) { + p2[i].resize(1); + p2[i][0] = std::move(p[i]); + } + std::vector> Ap2(Ap.size()); + for (auto i = 0u; i < Ap.size(); i++) { + Ap2[i].resize(1); + Ap2[i][0] = std::move(Ap[i]); + } + + computeCAKrylovSpace(diracm, Ap2, p2, n_krylov, basis, m_map, 
b_map, args...); + + for (auto i = 0u; i < p.size(); i++) p[i] = std::move(p2[i][0]); + for (auto i = 0u; i < Ap.size(); i++) Ap[i] = std::move(Ap2[i][0]); + } }; /** @@ -1200,36 +1221,36 @@ namespace quda { */ int n_krylov; - std::vector alpha; - std::vector beta; - std::vector gamma; + std::vector> alpha; + std::vector> beta; + std::vector> gamma; /** Solver uses lazy allocation: this flag to determine whether we have allocated. */ bool init = false; - ColorSpinorField r; //! residual vector - ColorSpinorField r_sloppy; //! sloppy residual vector + std::vector r; //! residual vector + std::vector r_sloppy; //! sloppy residual vector int k_break = 0; //! track when the solver converged - std::vector p; // GCR direction vectors - std::vector Ap; // mat * direction vectors + std::vector> p; // GCR direction vectors + std::vector> Ap; // mat * direction vectors - void computeBeta(std::vector &beta, std::vector &Ap, int i, int N, int k); - void updateAp(std::vector &beta, std::vector &Ap, int begin, int size, int k); - void orthoDir(std::vector &beta, std::vector &Ap, int k, int pipeline); + void computeBeta(std::vector &beta, cvector_ref &Ap, int i, int N, int k); + void updateAp(std::vector &beta, cvector_ref &Ap, int begin, int size, int k); + void orthoDir(std::vector &beta, cvector_ref &Ap, int k, int pipeline); void backSubs(const std::vector &alpha, const std::vector &beta, const std::vector &gamma, std::vector &delta, int n); void updateSolution(ColorSpinorField &x, const std::vector &alpha, const std::vector &beta, - std::vector &gamma, int k, std::vector &p); + std::vector &gamma, int k, cvector_ref &p); /** @brief Initiate the fields needed by the solver - @param[in] x Solution vector - @param[in] b Source vector + @param[in] x Solution vector set + @param[in] b Source vector set */ - void create(ColorSpinorField &x, const ColorSpinorField &b); + void create(cvector_ref &x, cvector_ref &b); public: GCR(const DiracMatrix &mat, const DiracMatrix &matSloppy, 
const DiracMatrix &matPrecon, const DiracMatrix &matEig, @@ -1242,10 +1263,7 @@ namespace quda { const DiracMatrix &matEig, SolverParam &param); virtual ~GCR(); - void operator()(cvector_ref &out, cvector_ref &in) override - { - for (auto i = 0u; i < in.size(); i++) operator()(out[i], in[i]); - } + void operator()(cvector_ref &out, cvector_ref &in) override; void operator()(ColorSpinorField &out, const ColorSpinorField &in); @@ -1279,13 +1297,7 @@ namespace quda { MR(const DiracMatrix &mat, const DiracMatrix &matSloppy, SolverParam &param); void operator()(cvector_ref &out, cvector_ref &in) override; -#if 0 - { - for (auto i = 0u; i < in.size(); i++) operator()(out[i], in[i]); - } - void operator()(ColorSpinorField &out, const ColorSpinorField &in); -#endif /** @return Return the residual vector from the prior solve */ @@ -1467,19 +1479,19 @@ namespace quda { bool lambda_init; // whether or not lambda_max has been initialized QudaCABasis basis; // CA basis - std::vector alpha; // Solution coefficient vectors + std::vector> alpha; // Solution coefficient vectors - ColorSpinorField r; + std::vector r; - std::vector p; // GCR direction vectors - std::vector q; // mat * direction vectors + std::vector> p; // GCR direction vectors + std::vector> q; // mat * direction vectors /** @brief Initiate the fields needed by the solver @param[in] x Solution vector @param[in] b Source vector */ - void create(ColorSpinorField &x, const ColorSpinorField &b); + void create(cvector_ref &x, cvector_ref &b); /** @brief Solve the equation A p_k psi_k = q_k psi_k = b by minimizing the @@ -1488,19 +1500,14 @@ namespace quda { @param[in] q Search direction vectors with the operator applied @param[in] b Source vector against which we are solving */ - void solve(std::vector &psi, std::vector &q, ColorSpinorField &b); + void solve(std::vector &psi, cvector_ref &q, ColorSpinorField &b); public: CAGCR(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, const DiracMatrix
&matEig, SolverParam &param); virtual ~CAGCR(); - void operator()(cvector_ref &out, cvector_ref &in) override - { - for (auto i = 0u; i < in.size(); i++) operator()(out[i], in[i]); - } - - void operator()(ColorSpinorField &out, const ColorSpinorField &in); + void operator()(cvector_ref &out, cvector_ref &in) override; /** @return Return the residual vector from the prior solve @@ -1561,12 +1568,13 @@ namespace quda { void operator()(cvector_ref &x, cvector_ref &b) override { + if (x.size() != b.size()) errorQuda("Mismatched set sizes %lu != %lu", x.size(), b.size()); pushOutputPrefix(prefix); QudaSolutionType solution_type = b.SiteSubset() == QUDA_FULL_SITE_SUBSET ? QUDA_MAT_SOLUTION : QUDA_MATPC_SOLUTION; - ColorSpinorField out; - ColorSpinorField in; + std::vector out(b.size()); + std::vector in(b.size()); if (dirac.hasSpecialMG()) { dirac.prepareSpecialMG(out, in, x, b, solution_type); diff --git a/lib/inv_ca_gcr.cpp b/lib/inv_ca_gcr.cpp index abd68aad3d..1b5fdb372e 100644 --- a/lib/inv_ca_gcr.cpp +++ b/lib/inv_ca_gcr.cpp @@ -21,18 +21,19 @@ namespace quda destroyDeflationSpace(); } - void CAGCR::create(ColorSpinorField &x, const ColorSpinorField &b) + void CAGCR::create(cvector_ref &x, cvector_ref &b) { Solver::create(x, b); if (!init) { if (!param.is_preconditioner) getProfile().TPSTART(QUDA_PROFILE_INIT); - alpha.resize(param.Nkrylov); + alpha.resize(b.size()); + for (auto i = 0u; i < alpha.size(); i++) alpha[i].resize(param.Nkrylov); bool mixed = param.precision != param.precision_sloppy; - ColorSpinorParam csParam(b); + ColorSpinorParam csParam(b[0]); csParam.create = QUDA_NULL_FIELD_CREATE; csParam.setPrecision(param.precision_sloppy); @@ -41,27 +42,30 @@ namespace quda p.resize(param.Nkrylov + 1); q.resize(param.Nkrylov); for (int i = 0; i < param.Nkrylov + 1; i++) { - p[i] = ColorSpinorField(csParam); - if (i > 0) q[i - 1] = p[i].create_alias(csParam); + resize(p[i], b.size(), csParam); + if (i > 0) create_alias(q[i - 1], p[i]); } } else {
p.resize(param.Nkrylov); q.resize(param.Nkrylov); for (int i = 0; i < param.Nkrylov; i++) { - p[i] = ColorSpinorField(csParam); - q[i] = ColorSpinorField(csParam); + resize(p[i], b.size(), csParam); + resize(q[i], b.size(), csParam); } } csParam.setPrecision(param.precision); - r = mixed ? ColorSpinorField(csParam) : p[0].create_alias(csParam); + if (mixed) + resize(r, b.size(), csParam); + else + create_alias(r, p[0]); if (!param.is_preconditioner) getProfile().TPSTOP(QUDA_PROFILE_INIT); init = true; } // init } - void CAGCR::solve(std::vector &psi_, std::vector &q, ColorSpinorField &b) + void CAGCR::solve(std::vector &psi_, cvector_ref &q, ColorSpinorField &b) { typedef Matrix matrix; typedef Matrix vector; @@ -127,7 +131,7 @@ namespace quda 3. Update solution and residual vectors 4. (Optional) restart if convergence or maxiter not reached */ - void CAGCR::operator()(ColorSpinorField &x, const ColorSpinorField &b) + void CAGCR::operator()(cvector_ref &x, cvector_ref &b) { const int n_krylov = param.Nkrylov; @@ -143,16 +147,16 @@ namespace quda // compute b2, but only if we need to bool fixed_iteration = param.sloppy_converge && n_krylov == param.maxiter && !param.compute_true_res; - double b2 = !fixed_iteration ? blas::norm2(b) : 1.0; - double r2 = 0.0; // if zero source then we will exit immediately doing no work + auto b2 = !fixed_iteration ? blas::norm2(b) : vector(b.size(), 1.0); + std::vector r2(b.size(), 0.0); // if zero source then we will exit immediately doing no work if (param.deflate) { // Construct the eigensolver and deflation space if requested. 
if (param.eig_param.eig_type == QUDA_EIG_TR_LANCZOS || param.eig_param.eig_type == QUDA_EIG_BLK_TR_LANCZOS) { - constructDeflationSpace(b, matMdagM); + constructDeflationSpace(b[0], matMdagM); } else { // Use Arnoldi to inspect the space only and turn off deflation - constructDeflationSpace(b, mat); + constructDeflationSpace(b[0], mat); param.deflate = false; } if (deflate_compute) { @@ -217,7 +221,7 @@ namespace quda // Perform 100 power iterations, normalizing every 10 mat-vecs, using r_ as an initial seed // and q[0]/q[1] as temporaries for the power iterations. Technically illegal if n_krylov == 1, but in that case lambda_max isn't used anyway. - lambda_max = 1.1 * Solver::performPowerIterations(matSloppy, r, q[0], q[1], 100, 10); + lambda_max = 1.1 * Solver::performPowerIterations(matSloppy, r[0], q[0][0], q[1][0], 100, 10); logQuda(QUDA_SUMMARIZE, "CA-GCR Approximate lambda max = 1.1 x %e\n", lambda_max / 1.1); lambda_init = true; @@ -233,19 +237,26 @@ namespace quda double b_map = -(lambda_max + lambda_min) / (lambda_max - lambda_min); // Check to see that we're not trying to invert on a zero-field source - if (b2 == 0) { - if (param.compute_null_vector == QUDA_COMPUTE_NULL_VECTOR_NO) { - warningQuda("inverting on zero-field source\n"); - x = b; - param.true_res = 0.0; - param.true_res_hq = 0.0; + if (param.compute_null_vector == QUDA_COMPUTE_NULL_VECTOR_NO) { + bool zero_src = true; + for (auto i = 0u; i < b.size(); i++) { + if (b2[i] == 0) { + warningQuda("inverting on zero-field source"); + x[i] = b[i]; + param.true_res[i] = 0.0; + param.true_res_hq[i] = 0.0; + } else { + zero_src = false; + } + } + if (zero_src) { + getProfile().TPSTOP(QUDA_PROFILE_INIT); return; - } else { - b2 = r2; } } - double stop = !fixed_iteration ? stopping(param.tol, b2, param.residual_type) : 0.0; // stopping condition of solver + auto stop = !fixed_iteration ? 
stopping(param.tol, b2, param.residual_type) : + std::vector(b.size(), 0.0); // stopping condition of solver const bool use_heavy_quark_res = (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) ? true : false; @@ -255,8 +266,9 @@ namespace quda const int maxResIncrease = param.max_res_increase; // check if we reached the limit of our tolerance const int maxResIncreaseTotal = param.max_res_increase_total; - double heavy_quark_res = 0.0; // heavy quark residual - if (use_heavy_quark_res) heavy_quark_res = sqrt(blas::HeavyQuarkResidualNorm(x, r).z); + std::vector heavy_quark_res(b.size()); // heavy quark residual + if (use_heavy_quark_res) + for (auto i = 0u; i < b.size(); i++) heavy_quark_res[i] = sqrt(blas::HeavyQuarkResidualNorm(x, r)[i].z); int resIncrease = 0; int resIncreaseTotal = 0; @@ -267,29 +279,41 @@ namespace quda } int total_iter = 0; int restart = 0; - double r2_old = r2; - double maxr_deflate = sqrt(r2); + auto r2_old = r2; + auto maxr_deflate = sqrt(r2[0]); bool l2_converge = false; blas::copy(p[0], r); // no op if uni-precision + auto get_i = [](std::vector> &p, int i) { + vector_ref p_i; + p_i.reserve(p.size()); + for (auto &pi : p) p_i.push_back(pi[i]); + return p_i; + }; + PrintStats("CA-GCR", total_iter, r2, b2, heavy_quark_res); while (!convergence(r2, heavy_quark_res, stop, param.tol_hq) && total_iter < param.maxiter) { // build up a space of size n_krylov computeCAKrylovSpace(matSloppy, q, p, n_krylov, basis, m_map, b_map); - solve(alpha, q, p[0]); + for (auto i = 0u; i < b.size(); i++) solve(alpha[i], get_i(q, i), p[0][i]); // need to make sure P is only length n_krylov - blas::block::caxpy(alpha, {p.begin(), p.begin() + n_krylov}, {x}); + for (auto i = 0u; i < b.size(); i++) { + auto pi = get_i(p, i); + blas::block::caxpy(alpha[i], {pi.begin(), pi.begin() + n_krylov}, {x[i]}); + } // no need to compute residual vector if not returning // residual vector and only doing a single fixed iteration if (!fixed_iteration || param.return_residual) { - 
// update the residual vector - for (int i = 0; i < n_krylov; i++) alpha[i] = -alpha[i]; - blas::block::caxpy(alpha, q, r); + for (auto i = 0u; i < b.size(); i++) { + // update the residual vector + for (int j = 0; j < n_krylov; j++) alpha[i][j] = -alpha[i][j]; + blas::block::caxpy(alpha[i], get_i(q, i), r[i]); + } } total_iter += n_krylov; @@ -302,13 +326,13 @@ namespace quda // update since n_krylov or maxiter reached, converged or reliable update required // note that the heavy quark residual will by definition only be checked every n_krylov steps - if (total_iter >= param.maxiter || (r2 < stop && !l2_converge) || sqrt(r2 / r2_old) < param.delta) { + if (total_iter >= param.maxiter || (r2 < stop && !l2_converge) || sqrt(r2[0] / r2_old[0]) < param.delta) { if ((r2 < stop || total_iter >= param.maxiter) && param.sloppy_converge) break; mat(r, x); r2 = blas::xmyNorm(b, r); - if (param.deflate && sqrt(r2) < maxr_deflate * param.tol_restart) { + if (param.deflate && sqrt(r2[0]) < maxr_deflate * param.tol_restart) { // Deflate and accumulate to solution vector eig_solve->deflateSVD(x, r, evecs, evals, true); @@ -316,10 +340,11 @@ namespace quda mat(r, x); r2 = blas::xmyNorm(b, r); - maxr_deflate = sqrt(r2); + maxr_deflate = sqrt(r2[0]); } - if (use_heavy_quark_res) heavy_quark_res = sqrt(blas::HeavyQuarkResidualNorm(x, r).z); + if (use_heavy_quark_res) + for (auto i = 0u; i < b.size(); i++) heavy_quark_res[i] = sqrt(blas::HeavyQuarkResidualNorm(x, r)[i].z); // break-out check if we have reached the limit of the precision if (r2 > r2_old) { @@ -327,7 +352,7 @@ namespace quda resIncreaseTotal++; warningQuda( "CA-GCR: new reliable residual norm %e is greater than previous reliable residual norm %e (total #inc %i)", - sqrt(r2), sqrt(r2_old), resIncreaseTotal); + sqrt(r2[0]), sqrt(r2_old[0]), resIncreaseTotal); if (resIncrease > maxResIncrease or resIncreaseTotal > maxResIncreaseTotal) { warningQuda("CA-GCR: solver exiting due to too many true residual norm increases"); 
break; @@ -361,10 +386,12 @@ namespace quda if (param.compute_true_res) { // Calculate the true residual mat(r, x); - double true_res = blas::xmyNorm(b, r); - param.true_res = sqrt(true_res / b2); - param.true_res_hq - = (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) ? sqrt(blas::HeavyQuarkResidualNorm(x, r).z) : 0.0; + auto true_r2 = blas::xmyNorm(b, r); + auto hq = blas::HeavyQuarkResidualNorm(x, r); + for (auto i = 0u; i < b.size(); i++) { + param.true_res[i] = sqrt(true_r2[i] / b2[i]); + param.true_res_hq[i] = sqrt(hq[i].z); + } } if (!param.is_preconditioner) { diff --git a/lib/inv_gcr_quda.cpp b/lib/inv_gcr_quda.cpp index aa9920a123..f040e899be 100644 --- a/lib/inv_gcr_quda.cpp +++ b/lib/inv_gcr_quda.cpp @@ -21,7 +21,7 @@ namespace quda { return ds + 0.000001*dus; } - void GCR::computeBeta(std::vector &beta, std::vector &Ap, int i, int N, int k) + void GCR::computeBeta(std::vector &beta, cvector_ref &Ap, int i, int N, int k) { std::vector Beta(N, 0.0); blas::block::cDotProduct(Beta, {Ap.begin() + i, Ap.begin() + i + N}, Ap[k]); // vectorized dot product @@ -36,14 +36,14 @@ namespace quda { for (int j = 0; j < N; j++) beta[(i + j) * n_krylov + k] = Beta[j]; } - void GCR::updateAp(std::vector &beta, std::vector &Ap, int begin, int size, int k) + void GCR::updateAp(std::vector &beta, cvector_ref &Ap, int begin, int size, int k) { std::vector beta_(size); for (int i = 0; i < size; i++) beta_[i] = -beta[(i + begin) * n_krylov + k]; blas::block::caxpy(beta_, {Ap.begin() + begin, Ap.begin() + begin + size}, Ap[k]); } - void GCR::orthoDir(std::vector &beta, std::vector &Ap, int k, int pipeline) + void GCR::orthoDir(std::vector &beta, cvector_ref &Ap, int k, int pipeline) { switch (pipeline) { case 0: // no kernel fusion @@ -64,8 +64,8 @@ namespace quda { { const int N = pipeline; for (int i=0; i &alpha, const std::vector &beta, - std::vector &gamma, int k, std::vector &p) + std::vector &gamma, int k, cvector_ref &p) { std::vector delta(k); @@ -109,10 +109,7 @@ 
namespace quda { matMdagM(DiracMdagM(matEig.Expose())), K(0), Kparam(param), - n_krylov(param.Nkrylov), - alpha(n_krylov), - beta(n_krylov * n_krylov), - gamma(n_krylov) + n_krylov(param.Nkrylov) { fillInnerSolverParam(Kparam, param); @@ -128,10 +125,7 @@ namespace quda { matMdagM(matEig.Expose()), K(nullptr), Kparam(param), - n_krylov(param.Nkrylov), - alpha(n_krylov), - beta(n_krylov * n_krylov), - gamma(n_krylov) + n_krylov(param.Nkrylov) { fillInnerSolverParam(Kparam, param); K = wrapExternalPreconditioner(K_); @@ -142,35 +136,49 @@ namespace quda { destroyDeflationSpace(); } - void GCR::create(ColorSpinorField &x, const ColorSpinorField &b) + void GCR::create(cvector_ref &x, cvector_ref &b) { Solver::create(x, b); if (!init) { getProfile().TPSTART(QUDA_PROFILE_INIT); - ColorSpinorParam csParam(x); + ColorSpinorParam csParam(x[0]); csParam.create = QUDA_NULL_FIELD_CREATE; // create sloppy fields used for orthogonalization csParam.setPrecision(param.precision_sloppy); - resize(p, n_krylov + 1, QUDA_NULL_FIELD_CREATE, csParam); - resize(Ap, n_krylov, QUDA_NULL_FIELD_CREATE, csParam); + p.resize(n_krylov + 1); + Ap.resize(n_krylov); + for (auto &p_ : p) resize(p_, b.size(), QUDA_NULL_FIELD_CREATE, csParam); + for (auto &ap : Ap) resize(ap, b.size(), QUDA_NULL_FIELD_CREATE, csParam); csParam.setPrecision(param.precision); if (K || mixed()) { - r = ColorSpinorField(csParam); + resize(r, b.size(), csParam); } else { - r = p[0].create_alias(); + create_alias(r, p[0]); } csParam.setPrecision(param.precision_sloppy); if (!K) { - r_sloppy = p[0].create_alias(); + create_alias(r_sloppy, p[0]); + } else if (!mixed()) { + create_alias(r_sloppy, r); } else { - r_sloppy = mixed() ? 
ColorSpinorField(csParam) : r.create_alias(); + resize(r_sloppy, b.size(), csParam); } getProfile().TPSTOP(QUDA_PROFILE_INIT); + + alpha.resize(b.size()); + beta.resize(b.size()); + gamma.resize(b.size()); + for (auto i = 0u; i < b.size(); i++) { + alpha[i].resize(n_krylov); + beta[i].resize(n_krylov * n_krylov); + gamma[i].resize(n_krylov); + } + init = true; } } @@ -184,7 +192,7 @@ namespace quda { return K ? r_sloppy : p[k_break]; } - void GCR::operator()(ColorSpinorField &x, const ColorSpinorField &b) + void GCR::operator()(cvector_ref &x, cvector_ref &b) { if (n_krylov == 0) { // Krylov space is zero-dimensional so return doing no work @@ -198,15 +206,14 @@ namespace quda { if (param.deflate) { // Construct the eigensolver and deflation space if requested. if (param.eig_param.eig_type == QUDA_EIG_TR_LANCZOS || param.eig_param.eig_type == QUDA_EIG_BLK_TR_LANCZOS) { - constructDeflationSpace(b, matMdagM); + constructDeflationSpace(b[0], matMdagM); } else { // Use Arnoldi to inspect the space only and turn off deflation - constructDeflationSpace(b, mat); + constructDeflationSpace(b[0], mat); param.deflate = false; } if (deflate_compute) { // compute the deflation space. 
- getProfile().TPSTOP(QUDA_PROFILE_INIT); (*eig_solve)(evecs, evals); if (param.deflate) { // double the size of the Krylov space @@ -214,7 +221,6 @@ namespace quda { // populate extra memory with L/R singular vectors eig_solve->computeSVD(evecs, evals); } - getProfile().TPSTART(QUDA_PROFILE_INIT); deflate_compute = false; } if (recompute_evals) { @@ -224,8 +230,8 @@ namespace quda { } } - double b2 = blas::norm2(b); // norm sq of source - double r2; // norm sq of residual + vector b2 = blas::norm2(b); // norm sq of source + vector r2; // norm sq of residual // compute initial residual depending on whether we have an initial guess or not if (param.use_init_guess == QUDA_USE_INIT_GUESS_YES) { @@ -249,20 +255,25 @@ namespace quda { } // Check to see that we're not trying to invert on a zero-field source - if (b2 == 0) { - if (param.compute_null_vector == QUDA_COMPUTE_NULL_VECTOR_NO) { + if (param.compute_null_vector == QUDA_COMPUTE_NULL_VECTOR_NO) { + bool zero_src = true; + for (auto i = 0u; i < b.size(); i++) { + if (b2[i] == 0) { + warningQuda("inverting on zero-field source"); + x[i] = b[i]; + param.true_res[i] = 0.0; + param.true_res_hq[i] = 0.0; + } else { + zero_src = false; + } + } + if (zero_src) { getProfile().TPSTOP(QUDA_PROFILE_INIT); - warningQuda("inverting on zero-field source\n"); - x = b; - param.true_res = 0.0; - param.true_res_hq = 0.0; - return; - } else { - b2 = r2; + return; } } - double stop = stopping(param.tol, b2, param.residual_type); // stopping condition of solver + auto stop = stopping(param.tol, b2, param.residual_type); // stopping condition of solver const bool use_heavy_quark_res = (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) ? 
true : false; @@ -273,8 +284,9 @@ namespace quda { const int maxResIncrease = param.max_res_increase; // check if we reached the limit of our tolerance const int maxResIncreaseTotal = param.max_res_increase_total; - double heavy_quark_res = 0.0; // heavy quark residual - if(use_heavy_quark_res) heavy_quark_res = sqrt(blas::HeavyQuarkResidualNorm(x,r).z); + std::vector heavy_quark_res(b.size()); // heavy quark residual + if (use_heavy_quark_res) + for (auto i = 0u; i < b.size(); i++) heavy_quark_res[i] = sqrt(blas::HeavyQuarkResidualNorm(x, r)[i].z); int resIncrease = 0; int resIncreaseTotal = 0; @@ -286,13 +298,11 @@ namespace quda { int total_iter = 0; int restart = 0; - double r2_old = r2; - double maxr_deflate = sqrt(r2); + auto r2_old = r2; + double maxr_deflate = sqrt(r2[0]); bool l2_converge = false; int pipeline = param.pipeline; - // Vectorized dot product only has limited support so work around - if (Ap[0].Location() == QUDA_CPU_FIELD_LOCATION || pipeline == 0) pipeline = 1; if (pipeline > n_krylov) pipeline = n_krylov; getProfile().TPSTOP(QUDA_PROFILE_PREAMBLE); @@ -314,16 +324,29 @@ namespace quda { matSloppy(Ap[k], p[k]); - orthoDir(beta, Ap, k, pipeline); + auto get_i = [](std::vector> &p, int i) { + vector_ref p_i; + p_i.reserve(p.size()); + for (auto &pi : p) p_i.push_back(pi[i]); + return p_i; + }; - double3 Apr = blas::cDotProductNormA(Ap[k], K ? r_sloppy : p[k]); + for (auto i = 0u; i < b.size(); i++) orthoDir(beta[i], get_i(Ap, i), k, pipeline); - gamma[k] = sqrt(Apr.z); // gamma[k] = Ap[k] - if (gamma[k] == 0.0) errorQuda("GCR breakdown"); - alpha[k] = Complex(Apr.x, Apr.y) / gamma[k]; // alpha = (1/|Ap|) * (Ap, r) + auto Apr = blas::cDotProductNormA(Ap[k], K ? 
r_sloppy : p[k]); + + for (auto i = 0u; i < b.size(); i++) { + gamma[i][k] = sqrt(Apr[i].z); // gamma[k] = Ap[k] + if (gamma[i][k] == 0.0) errorQuda("GCR breakdown"); + alpha[i][k] = Complex(Apr[i].x, Apr[i].y) / gamma[i][k]; // alpha = (1/|Ap|) * (Ap, r) + } // r -= (1/|Ap|^2) * (Ap, r) r, Ap *= 1/|Ap| - r2 = blas::cabxpyzAxNorm(1.0 / gamma[k], -alpha[k], Ap[k], K ? r_sloppy : p[k], K ? r_sloppy : p[k + 1]); + std::vector gamma_k_inv(b.size()); + for (auto i = 0u; i < gamma_k_inv.size(); i++) gamma_k_inv[i] = 1.0 / gamma[i][k]; + std::vector alpha_k(b.size()); + for (auto i = 0u; i < alpha_k.size(); i++) alpha_k[i] = -alpha[i][k]; + r2 = blas::cabxpyzAxNorm(gamma_k_inv, alpha_k, Ap[k], K ? r_sloppy : p[k], K ? r_sloppy : p[k + 1]); k++; total_iter++; @@ -332,16 +355,17 @@ namespace quda { // update since n_krylov or maxiter reached, converged or reliable update required // note that the heavy quark residual will by definition only be checked every n_krylov steps - if (k == n_krylov || total_iter == param.maxiter || (r2 < stop && !l2_converge) || sqrt(r2 / r2_old) < param.delta) { + if (k == n_krylov || total_iter == param.maxiter || (r2[0] < stop[0] && !l2_converge) + || sqrt(r2[0] / r2_old[0]) < param.delta) { // update the solution vector - updateSolution(x, alpha, beta, gamma, k, p); + for (auto i = 0u; i < b.size(); i++) updateSolution(x[i], alpha[i], beta[i], gamma[i], k, get_i(p, i)); if ( (r2 < stop || total_iter==param.maxiter) && param.sloppy_converge) break; mat(r, x); r2 = blas::xmyNorm(b, r); - if (param.deflate && sqrt(r2) < maxr_deflate * param.tol_restart) { + if (param.deflate && sqrt(r2[0]) < maxr_deflate * param.tol_restart) { // Deflate: Hardcoded to SVD. 
eig_solve->deflateSVD(x, r, evecs, evals, true); @@ -349,17 +373,19 @@ namespace quda { mat(r, x); r2 = blas::xmyNorm(b, r); - maxr_deflate = sqrt(r2); + maxr_deflate = sqrt(r2[0]); } - if (use_heavy_quark_res) heavy_quark_res = sqrt(blas::HeavyQuarkResidualNorm(x, r).z); + if (use_heavy_quark_res) + for (auto i = 0u; i < b.size(); i++) heavy_quark_res[i] = sqrt(blas::HeavyQuarkResidualNorm(x, r)[i].z); // break-out check if we have reached the limit of the precision if (r2 > r2_old) { resIncrease++; resIncreaseTotal++; - warningQuda("GCR: new reliable residual norm %e is greater than previous reliable residual norm %e (total #inc %i)", - sqrt(r2), sqrt(r2_old), resIncreaseTotal); + warningQuda( + "GCR: new reliable residual norm %e is greater than previous reliable residual norm %e (total #inc %i)", + sqrt(r2[0]), sqrt(r2_old[0]), resIncreaseTotal); if (resIncrease > maxResIncrease or resIncreaseTotal > maxResIncreaseTotal) { warningQuda("GCR: solver exiting due to too many true residual norm increases"); break; @@ -397,12 +423,12 @@ namespace quda { if (param.compute_true_res) { // Calculate the true residual mat(r, x); - double true_res = blas::xmyNorm(b, r); - param.true_res = sqrt(true_res / b2); - if (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) - param.true_res_hq = sqrt(blas::HeavyQuarkResidualNorm(x,r).z); - else - param.true_res_hq = 0.0; + auto true_r2 = blas::xmyNorm(b, r); + auto hq = blas::HeavyQuarkResidualNorm(x, r); + for (auto i = 0u; i < b.size(); i++) { + param.true_res[i] = sqrt(true_r2[i] / b2[i]); + param.true_res_hq[i] = sqrt(hq[i].z); + } } param.iter += total_iter; diff --git a/lib/solver.hpp b/lib/solver.hpp index ab93451aef..4bf6ffb618 100644 --- a/lib/solver.hpp +++ b/lib/solver.hpp @@ -51,8 +51,8 @@ namespace quda /** @brief Generate a Krylov space in a given basis @param[in] diracm Dirac matrix used to generate the Krylov space - @param[out] Ap dirac matrix times the Krylov basis vectors - @param[in,out] p Krylov basis vectors; 
assumes p[0] is in place + @param[out] Ap dirac matrix times the Krylov basis vector sets + @param[in,out] p Krylov basis vector sets; assumes p[0] is in place @param[in] n_krylov Size of krylov space @param[in] basis Basis type @param[in] m_map Slope mapping for Chebyshev basis; ignored for power basis @@ -60,9 +60,9 @@ namespace quda @param[in] args Parameter pack of ColorSpinorFields used as temporary passed to Dirac */ template - void Solver::computeCAKrylovSpace(const DiracMatrix &diracm, std::vector &Ap, - std::vector &p, int n_krylov, QudaCABasis basis, double m_map, - double b_map, Args &&...args) + void Solver::computeCAKrylovSpace(const DiracMatrix &diracm, std::vector> &Ap, + std::vector> &p, int n_krylov, + QudaCABasis basis, double m_map, double b_map, Args &&...args) { // in some cases p or Ap may be larger if (static_cast(p.size()) < n_krylov) errorQuda("Invalid p.size() %lu < n_krylov %d", p.size(), n_krylov); From f3a3d8e4fe55afd9becef046ba89f5c4df59fdd7 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Fri, 26 Jul 2024 05:06:21 -0700 Subject: [PATCH 036/103] Multigrid solver is now MRHS aware --- include/multigrid.h | 19 ++-- lib/multigrid.cpp | 214 ++++++++++++++++++++++---------------------- 2 files changed, 113 insertions(+), 120 deletions(-) diff --git a/include/multigrid.h b/include/multigrid.h index aa500f6d56..798d1abf4e 100644 --- a/include/multigrid.h +++ b/include/multigrid.h @@ -299,14 +299,14 @@ namespace quda { /** The coarse-grid representation of the null space vectors */ std::vector B_coarse; - /** Residual vector */ - ColorSpinorField r; + /** Residual vector set */ + std::vector r; - /** Coarse residual vector */ - ColorSpinorField r_coarse; + /** Coarse residual vector set */ + std::vector r_coarse; - /** Coarse solution vector */ - ColorSpinorField x_coarse; + /** Coarse solution vector set */ + std::vector x_coarse; /** Kahler-Dirac Xinv */ std::shared_ptr xInvKD; @@ -445,12 +445,7 @@ namespace quda { @param out The solution
vector @param in The residual vector (or equivalently the right hand side vector) */ - void operator()(cvector_ref &out, cvector_ref &in) - { - for (auto i = 0u; i < in.size(); i++) operator()(out[i], in[i]); - } - - void operator()(ColorSpinorField &out, const ColorSpinorField &in); + void operator()(cvector_ref &out, cvector_ref &in); /** @brief Load the null space vectors in from file diff --git a/lib/multigrid.cpp b/lib/multigrid.cpp index 9a239e37b2..6f7dc19a50 100644 --- a/lib/multigrid.cpp +++ b/lib/multigrid.cpp @@ -13,7 +13,7 @@ namespace quda using namespace blas; - static bool debug = false; + static constexpr bool debug = false; MG::MG(MGParam &param) : Solver(*param.matResidual, *param.matSmooth, *param.matSmoothSloppy, *param.matSmoothSloppy, param), @@ -46,7 +46,7 @@ namespace quda } if (param.B[0].Nspin() == 1) csParam.gammaBasis = param.B[0].GammaBasis(); // hack for staggered to avoid unnecessary basis checks - r = ColorSpinorField(csParam); + resize(r, 1, csParam); } rng = new RNG(param.B[0], 1234); @@ -58,8 +58,8 @@ namespace quda // Initializing to random vectors for (int i = 0; i < (int)param.B.size(); i++) { - spinorNoise(r, *rng, QUDA_NOISE_UNIFORM); - param.B[i] = r; + spinorNoise(r[0], *rng, QUDA_NOISE_UNIFORM); + param.B[i] = r[0]; } } if (param.mg_global.num_setup_iter[param.level] > 0) { @@ -132,14 +132,18 @@ namespace quda param.mg_global.geo_block_size[param.level][i] = param.geoBlockSize[i]; // create coarse residual vector if not already created in verify() - if (r_coarse.empty()) - r_coarse = param.B[0].create_coarse(param.geoBlockSize, param.spinBlockSize, param.Nvec, r.Precision(), + if (r_coarse.empty()) { + r_coarse.resize(1); + r_coarse[0] = param.B[0].create_coarse(param.geoBlockSize, param.spinBlockSize, param.Nvec, r[0].Precision(), param.mg_global.location[param.level + 1]); + } // create coarse solution vector if not already created in verify() - if (x_coarse.empty()) - x_coarse = param.B[0].create_coarse(param.geoBlockSize,
param.spinBlockSize, param.Nvec, r.Precision(), + if (x_coarse.empty()) { + x_coarse.resize(1); + x_coarse[0] = param.B[0].create_coarse(param.geoBlockSize, param.spinBlockSize, param.Nvec, r[0].Precision(), param.mg_global.location[param.level + 1]); + } int nVec_coarse = std::max(param.Nvec, param.mg_global.n_vec[param.level + 1]); B_coarse.resize(nVec_coarse); @@ -660,7 +664,7 @@ namespace quda param_coarse_solver->verbosity_precondition = param.mg_global.verbosity[param.level+1]; // preconditioned solver wrapper is uniform precision - param_coarse_solver->precision = r_coarse.Precision(); + param_coarse_solver->precision = r_coarse[0].Precision(); param_coarse_solver->precision_sloppy = param_coarse_solver->precision; param_coarse_solver->precision_precondition = param_coarse_solver->precision_sloppy; @@ -699,9 +703,9 @@ namespace quda // Run a dummy solve so that the deflation space is constructed and computed if needed during the MG setup, // or the eigenvalues are recomputed during transfer. - spinorNoise(r_coarse, *coarse->rng, QUDA_NOISE_UNIFORM); + spinorNoise(r_coarse[0], *coarse->rng, QUDA_NOISE_UNIFORM); param_coarse_solver->maxiter = 1; // do a single iteration on the dummy solve - (*coarse_solver)(x_coarse, r_coarse); + (*coarse_solver)(x_coarse[0], r_coarse[0]); setOutputPrefix(prefix); // restore since we just popped back from coarse grid param_coarse_solver->maxiter = param.mg_global.coarse_solver_maxiter[param.level + 1]; } @@ -760,8 +764,8 @@ namespace quda { pushLevel(param.level); - QudaPrecision prec = (param.mg_global.precision_null[param.level] < r.Precision()) ? - param.mg_global.precision_null[param.level] : r.Precision(); + QudaPrecision prec = (param.mg_global.precision_null[param.level] < r[0].Precision()) ? 
+ param.mg_global.precision_null[param.level] : r[0].Precision(); // may want to revisit this---these were relaxed for cases where ghost_precision < precision // these were set while hacking in tests of quarter precision ghosts @@ -776,12 +780,12 @@ namespace quda // temporary fields used for verification std::vector fine_tmp(param.Nvec); - ColorSpinorParam fine_param(r); + ColorSpinorParam fine_param(r[0]); fine_param.create = QUDA_NULL_FIELD_CREATE; for (auto &f : fine_tmp) f = ColorSpinorField(fine_param); std::vector coarse_tmp(param.Nvec); - ColorSpinorParam coarse_param(r_coarse); + ColorSpinorParam coarse_param(r_coarse[0]); coarse_param.create = QUDA_NULL_FIELD_CREATE; for (auto &c : coarse_tmp) c = ColorSpinorField(coarse_param); @@ -814,7 +818,7 @@ namespace quda if (check_deviation(max_deviation[i][0], tol)) errorQuda("k=%d orthonormality failed: max deviation %e > %e", i, max_deviation[i][0], tol); } - for (auto &f : fine_tmp) f.GammaBasis(r.GammaBasis()); // restore basis + for (auto &f : fine_tmp) f.GammaBasis(r[0].GammaBasis()); // restore basis // the oblique check if (param.mg_global.run_oblique_proj_check) { @@ -826,10 +830,10 @@ namespace quda logQuda(QUDA_SUMMARIZE, "Checking 1 > || (1 - DP(P^dagDP)P^dag) v_k || / || v_k || for %d vectors\n", param.Nvec); for (int i = 0; i < param.Nvec; i++) { - transfer->R(r_coarse, param.B[i]); - (*coarse_solver)(x_coarse, r_coarse); // this needs to be an exact solve to pass + transfer->R(r_coarse[0], param.B[i]); + (*coarse_solver)(x_coarse[0], r_coarse[0]); // this needs to be an exact solve to pass setOutputPrefix(prefix); // restore prefix after return from coarse grid - transfer->P(tmp2, x_coarse); + transfer->P(tmp2, x_coarse[0]); (*param.matResidual)(tmp1, tmp2); tmp2 = param.B[i]; logQuda(QUDA_SUMMARIZE, "Vector %d: norms %e %e\n", i, B_norm[i], norm2(tmp1)); @@ -847,9 +851,9 @@ namespace quda for (int i=0; iR(r_coarse, param.B[i]); - (*coarse)(x_coarse, r_coarse); // this needs to be an exact solve 
to pass + (*coarse)(x_coarse[0], r_coarse[0]); // this needs to be an exact solve to pass setOutputPrefix(prefix); // restore output prefix - transfer->P(tmp2, x_coarse); + transfer->P(tmp2, x_coarse[0]); param.matResidual(tmp1, tmp2); tmp2 = param.B[i]; logQuda(QUDA_SUMMARIZE, "Vector %d: norms %e %e ", i, B_norm[i], norm2(tmp1)); @@ -858,25 +862,29 @@ namespace quda #endif // create coarse residual vector if not already created in verify() - if (r_coarse.empty()) - r_coarse = param.B[0].create_coarse(param.geoBlockSize, param.spinBlockSize, param.Nvec, r.Precision(), + if (r_coarse.empty()) { + r_coarse.resize(1); + r_coarse[0] = param.B[0].create_coarse(param.geoBlockSize, param.spinBlockSize, param.Nvec, r[0].Precision(), param.mg_global.location[param.level + 1]); + } // create coarse solution vector if not already created in verify() - if (x_coarse.empty()) - x_coarse = param.B[0].create_coarse(param.geoBlockSize, param.spinBlockSize, param.Nvec, r.Precision(), - param.mg_global.location[param.level + 1]); + if (x_coarse.empty()) { + x_coarse.resize(1); + x_coarse[0] = param.B[0].create_coarse(param.geoBlockSize, param.spinBlockSize, param.Nvec, r[0].Precision(), + param.mg_global.location[param.level + 1]); + } { logQuda(QUDA_SUMMARIZE, "Checking 0 = (1 - P^\\dagger P) eta_c\n"); - spinorNoise(x_coarse, *rng, QUDA_NOISE_UNIFORM); - transfer->P(tmp2, x_coarse); - transfer->R(r_coarse, tmp2); - auto r2 = norm2(r_coarse); - auto max_deviation = blas::max_deviation(r_coarse, x_coarse); - auto l2_deviation = sqrt(xmyNorm(x_coarse, r_coarse) / norm2(x_coarse)); - logQuda(QUDA_VERBOSE, "L2 norms %e %e (fine tmp %e); Deviations: L2 relative = %e, max = %e\n", norm2(x_coarse), + spinorNoise(x_coarse[0], *rng, QUDA_NOISE_UNIFORM); + transfer->P(tmp2, x_coarse[0]); + transfer->R(r_coarse[0], tmp2); + auto r2 = norm2(r_coarse[0]); + auto max_deviation = blas::max_deviation(r_coarse[0], x_coarse[0]); + auto l2_deviation = sqrt(xmyNorm(x_coarse[0], r_coarse[0]) / 
norm2(x_coarse[0])); + logQuda(QUDA_VERBOSE, "L2 norms %e %e (fine tmp %e); Deviations: L2 relative = %e, max = %e\n", norm2(x_coarse[0]), r2, norm2(tmp2), l2_deviation, max_deviation[0]); if (check_deviation(l2_deviation, tol)) errorQuda("coarse span failed: L2 relative deviation = %e > %e", l2_deviation, tol); @@ -955,7 +963,7 @@ namespace quda (*param.matResidual)(tmp2, tmp1); } - transfer->R(x_coarse, tmp2); + transfer->R(x_coarse[0], tmp2); static_cast(diracCoarseResidual)->M(r_coarse, tmp_coarse); #if 0 // enable to print out emulated and actual coarse-grid operator vectors for debugging @@ -966,20 +974,20 @@ namespace quda printfQuda("\nemulated\n"); comm_barrier(); for (int parity = 0; parity < 2; parity++) - for (unsigned int x_cb = 0; x_cb < x_coarse.VolumeCB(); x_cb++) x_coarse.PrintVector(parity, x_cb, rank); + for (unsigned int x_cb = 0; x_cb < x_coarse[0].VolumeCB(); x_cb++) x_coarse[0].PrintVector(parity, x_cb, rank); comm_barrier(); printfQuda("\nactual\n"); comm_barrier(); for (int parity = 0; parity < 2; parity++) - for (unsigned int x_cb = 0; x_cb < r_coarse.VolumeCB(); x_cb++) r_coarse.PrintVector(parity, x_cb, rank); + for (unsigned int x_cb = 0; x_cb < r_coarse[0].VolumeCB(); x_cb++) r_coarse[0].PrintVector(parity, x_cb, rank); } setOutputPrefix(prefix); #endif - double r_nrm = norm2(r_coarse); - auto max_deviation = blas::max_deviation(r_coarse, x_coarse); - auto l2_deviation = sqrt(xmyNorm(x_coarse, r_coarse) / norm2(x_coarse)); + double r_nrm = norm2(r_coarse[0]); + auto max_deviation = blas::max_deviation(r_coarse[0], x_coarse[0]); + auto l2_deviation = sqrt(xmyNorm(x_coarse[0], r_coarse[0]) / norm2(x_coarse[0])); if (diracResidual->Mu() != 0.0) { // When the mu is shifted on the coarse level; we can compute exactly the error we introduce in the check: @@ -988,13 +996,13 @@ namespace quda if (fabs(delta_factor) > tol) { double delta_a = delta_factor * 2.0 * diracResidual->Kappa() * diracResidual->Mu() * transfer->Vectors().TwistFlavor(); - 
l2_deviation -= fabs(delta_a) * sqrt(norm2(tmp_coarse) / norm2(x_coarse)); + l2_deviation -= fabs(delta_a) * sqrt(norm2(tmp_coarse) / norm2(x_coarse[0])); l2_deviation = fabs(l2_deviation); max_deviation[0] -= fabs(delta_a); } } logQuda(QUDA_VERBOSE, "L2 norms: Emulated = %e, Native = %e; Deviations: L2 relative = %e, max = %e\n", - norm2(x_coarse), r_nrm, l2_deviation, max_deviation[0]); + norm2(x_coarse[0]), r_nrm, l2_deviation, max_deviation[0]); if (check_deviation(l2_deviation, tol)) errorQuda("Coarse operator failed: L2 relative deviation = %e > %e", l2_deviation, tol); @@ -1008,14 +1016,14 @@ namespace quda if (coarse_was_preconditioned) { // check eo logQuda(QUDA_SUMMARIZE, "Checking Deo of preconditioned operator 0 = \\hat{D}_c - A^{-1} D_c\n"); - static_cast(diracCoarseResidual)->Dslash(r_coarse.Even(), tmp_coarse.Odd(), QUDA_EVEN_PARITY); - static_cast(diracCoarseResidual)->CloverInv(x_coarse.Even(), r_coarse.Even(), QUDA_EVEN_PARITY); - static_cast(diracCoarseSmoother)->Dslash(r_coarse.Even(), tmp_coarse.Odd(), QUDA_EVEN_PARITY); - double r_nrm = norm2(r_coarse.Even()); - auto max_deviation = blas::max_deviation(r_coarse.Even(), x_coarse.Even()); - auto l2_deviation = sqrt(xmyNorm(x_coarse.Even(), r_coarse.Even()) / norm2(x_coarse.Even())); + static_cast(diracCoarseResidual)->Dslash(r_coarse[0].Even(), tmp_coarse.Odd(), QUDA_EVEN_PARITY); + static_cast(diracCoarseResidual)->CloverInv(x_coarse[0].Even(), r_coarse[0].Even(), QUDA_EVEN_PARITY); + static_cast(diracCoarseSmoother)->Dslash(r_coarse[0].Even(), tmp_coarse.Odd(), QUDA_EVEN_PARITY); + double r_nrm = norm2(r_coarse[0].Even()); + auto max_deviation = blas::max_deviation(r_coarse[0].Even(), x_coarse[0].Even()); + auto l2_deviation = sqrt(xmyNorm(x_coarse[0].Even(), r_coarse[0].Even()) / norm2(x_coarse[0].Even())); logQuda(QUDA_VERBOSE, "L2 norms: Emulated = %e, Native = %e; Deviations: L2 relative = %e, max = %e\n", - norm2(x_coarse.Even()), r_nrm, l2_deviation, max_deviation[0]); + 
norm2(x_coarse[0].Even()), r_nrm, l2_deviation, max_deviation[0]); if (check_deviation(l2_deviation, tol)) errorQuda("Preconditioned Deo failed: L2 relative deviation = %e > %e", l2_deviation, tol); if (check_deviation(max_deviation[0], tol)) @@ -1023,14 +1031,14 @@ namespace quda // check Doe logQuda(QUDA_SUMMARIZE, "Checking Doe of preconditioned operator 0 = \\hat{D}_c - A^{-1} D_c\n"); - static_cast(diracCoarseResidual)->Dslash(r_coarse.Odd(), tmp_coarse.Even(), QUDA_ODD_PARITY); - static_cast(diracCoarseResidual)->CloverInv(x_coarse.Odd(), r_coarse.Odd(), QUDA_ODD_PARITY); - static_cast(diracCoarseSmoother)->Dslash(r_coarse.Odd(), tmp_coarse.Even(), QUDA_ODD_PARITY); - r_nrm = norm2(r_coarse.Odd()); - max_deviation = blas::max_deviation(r_coarse.Odd(), x_coarse.Odd()); - l2_deviation = sqrt(xmyNorm(x_coarse.Odd(), r_coarse.Odd()) / norm2(x_coarse.Odd())); + static_cast(diracCoarseResidual)->Dslash(r_coarse[0].Odd(), tmp_coarse.Even(), QUDA_ODD_PARITY); + static_cast(diracCoarseResidual)->CloverInv(x_coarse[0].Odd(), r_coarse[0].Odd(), QUDA_ODD_PARITY); + static_cast(diracCoarseSmoother)->Dslash(r_coarse[0].Odd(), tmp_coarse.Even(), QUDA_ODD_PARITY); + r_nrm = norm2(r_coarse[0].Odd()); + max_deviation = blas::max_deviation(r_coarse[0].Odd(), x_coarse[0].Odd()); + l2_deviation = sqrt(xmyNorm(x_coarse[0].Odd(), r_coarse[0].Odd()) / norm2(x_coarse[0].Odd())); logQuda(QUDA_VERBOSE, "L2 norms: Emulated = %e, Native = %e; Deviations: L2 relative = %e, max = %e\n", - norm2(x_coarse.Odd()), r_nrm, l2_deviation, max_deviation[0]); + norm2(x_coarse[0].Odd()), r_nrm, l2_deviation, max_deviation[0]); if (check_deviation(l2_deviation, tol)) errorQuda("Preconditioned Doe failed: L2 relative deviation = %e > %e", l2_deviation, tol); if (check_deviation(max_deviation[0], tol)) @@ -1087,11 +1095,11 @@ namespace quda for (int i = 0; i < param.Nvec; i++) { // Restrict Evec, place result in r_coarse - transfer->R(r_coarse, param.B[i]); + transfer->R(r_coarse[0], param.B[i]); // 
Prolong r_coarse, place result in tmp2 - transfer->P(tmp2, r_coarse); + transfer->P(tmp2, r_coarse[0]); - printfQuda("Vector %d: norms v_k = %e P^dag v_k = %e PP^dag v_k = %e\n", i, B_norm[i], norm2(r_coarse), + printfQuda("Vector %d: norms v_k = %e P^dag v_k = %e PP^dag v_k = %e\n", i, B_norm[i], norm2(r_coarse[0]), norm2(tmp2)); // Compare v_k and PP^dag v_k. @@ -1108,10 +1116,10 @@ namespace quda // Oblique projections logQuda(QUDA_SUMMARIZE, "Checking 1 > || (1 - DP(P^dagDP)P^dag) v_k || / || v_k || for vector %d\n", i); - transfer->R(r_coarse, param.B[i]); - (*coarse_solver)(x_coarse, r_coarse); // this needs to be an exact solve to pass + transfer->R(r_coarse[0], param.B[i]); + (*coarse_solver)(x_coarse[0], r_coarse[0]); // this needs to be an exact solve to pass setOutputPrefix(prefix); // restore prefix after return from coarse grid - transfer->P(tmp2, x_coarse); + transfer->P(tmp2, x_coarse[0]); (*param.matResidual)(tmp1, tmp2); logQuda(QUDA_SUMMARIZE, "Vector %d: norms v_k %e DP(P^dagDP)P^dag v_k %e\n", i, B_norm[i], norm2(tmp1)); @@ -1132,8 +1140,12 @@ namespace quda popLevel(); } - void MG::operator()(ColorSpinorField &x, const ColorSpinorField &b) + void MG::operator()(cvector_ref &x, cvector_ref &b) { + resize(r, b.size(), QUDA_NULL_FIELD_CREATE); + resize(r_coarse, b.size(), QUDA_NULL_FIELD_CREATE); + resize(x_coarse, b.size(), QUDA_NULL_FIELD_CREATE); + pushOutputPrefix(prefix); if (param.level < param.Nlevel - 1) { // set parity for the solver in the transfer operator @@ -1152,7 +1164,7 @@ namespace quda QudaSolutionType outer_solution_type = b.SiteSubset() == QUDA_FULL_SITE_SUBSET ? 
QUDA_MAT_SOLUTION : QUDA_MATPC_SOLUTION; QudaSolutionType inner_solution_type = param.coarse_grid_solution_type; - if (debug) printfQuda("outer_solution_type = %d, inner_solution_type = %d\n", outer_solution_type, inner_solution_type); + if constexpr (debug) printfQuda("outer_solution_type = %d, inner_solution_type = %d\n", outer_solution_type, inner_solution_type); if ( outer_solution_type == QUDA_MATPC_SOLUTION && inner_solution_type == QUDA_MAT_SOLUTION) errorQuda("Unsupported solution type combination"); @@ -1160,30 +1172,25 @@ namespace quda if ( inner_solution_type == QUDA_MATPC_SOLUTION && param.smoother_solve_type != QUDA_DIRECT_PC_SOLVE) errorQuda("For this coarse grid solution type, a preconditioned smoother is required"); - if ( debug ) printfQuda("entering V-cycle with x2=%e, r2=%e\n", norm2(x), norm2(b)); + if constexpr ( debug ) for (auto i = 0u; i < b.size(); i++) + printfQuda("entering V-cycle with x2=%e, r2=%e\n", norm2(x[i]), norm2(b[i])); if (param.level < param.Nlevel-1) { //transfer->setTransferGPU(false); // use this to force location of transfer (need to check if still works for multi-level) // do the pre smoothing - if (debug) printfQuda("pre-smoothing b2=%e site subset %d\n", norm2(b), b.SiteSubset()); + if constexpr (debug) for (auto i = 0u; i < b.size(); i++) + printfQuda("pre-smoothing b2=%e site subset %d\n", norm2(b[i]), b.SiteSubset()); - ColorSpinorField out, in; + std::vector out(b.size()), in(b.size()); diracSmoother->prepare(out, in, x, b, outer_solution_type); - ColorSpinorField b_tilde; - // if we're using preconditioning then allocate storage for the preconditioned source vector - if (param.smoother_solve_type == QUDA_DIRECT_PC_SOLVE) { - b_tilde = getFieldTmp(r.Even()); - b_tilde = in; // b_tilde holds either a copy of preconditioned source or a pointer to original source - } - if (presmoother) (*presmoother)(out, in); else zero(out); - ColorSpinorField &solution = inner_solution_type == outer_solution_type ? 
x : x.Even(); + auto &solution = inner_solution_type == outer_solution_type ? x : x.Even(); diracSmoother->reconstruct(solution, b, inner_solution_type); // if using preconditioned smoother then need to reconstruct full residual @@ -1196,17 +1203,17 @@ namespace quda false; // FIXME this is currently borked if inner solver is preconditioned - const ColorSpinorField &residual = !presmoother ? b : - use_solver_residual ? presmoother->get_residual()[0] : - b.SiteSubset() == QUDA_FULL_SITE_SUBSET ? r : - r.Even(); + const auto &residual = !presmoother ? b : + use_solver_residual ? presmoother->get_residual() : + b.SiteSubset() == QUDA_FULL_SITE_SUBSET ? cvector_ref(r) : + cvector_ref(r).Even(); if (!use_solver_residual && presmoother) { - auto &residual = b.SiteSubset() == QUDA_FULL_SITE_SUBSET ? r : r.Even(); + auto &residual = b.SiteSubset() == QUDA_FULL_SITE_SUBSET ? cvector_ref(r) : cvector_ref(r).Even(); (*param.matResidual)(residual, x); axpby(1.0, b, -1.0, residual); } - double r2 = debug ? norm2(residual) : 0.0; + auto r2 = debug ? norm2(residual) : 0.0; // We need this to ensure that the coarse level has been created. // e.g. in case of iterative setup with MG we use just pre- and post-smoothing at the first iteration. 
@@ -1214,59 +1221,50 @@ namespace quda // restrict to the coarse grid transfer->R(r_coarse, residual); - if (debug) printfQuda("after pre-smoothing x2 = %e, r2 = %e, r_coarse2 = %e\n", norm2(x), r2, norm2(r_coarse)); + if constexpr (debug) + for (auto i = 0u; i < b.size(); i++) + printfQuda("after pre-smoothing x2 = %e, r2 = %e, r_coarse2 = %e\n", norm2(x[i]), r2[i], norm2(r_coarse[i])); // recurse to the next lower level (*coarse_solver)(x_coarse, r_coarse); - if (debug) printfQuda("after coarse solve x_coarse2 = %e r_coarse2 = %e\n", norm2(x_coarse), norm2(r_coarse)); + if constexpr (debug) + for (auto i = 0u; i < b.size(); i++) + printfQuda("after coarse solve x_coarse2 = %e r_coarse2 = %e\n", norm2(x_coarse[i]), norm2(r_coarse[i])); // prolongate back to this grid - ColorSpinorField &x_coarse_2_fine - = inner_solution_type == QUDA_MAT_SOLUTION ? r : r.Even(); // define according to inner solution type + auto &x_coarse_2_fine + = inner_solution_type == QUDA_MAT_SOLUTION ? cvector_ref(r) : cvector_ref(r).Even(); // define according to inner solution type transfer->P(x_coarse_2_fine, x_coarse); // repurpose residual storage xpy(x_coarse_2_fine, solution); // sum to solution FIXME - sum should be done inside the transfer operator - if ( debug ) { - printfQuda("Prolongated coarse solution y2 = %e\n", norm2(r)); - printfQuda("after coarse-grid correction x2 = %e, r2 = %e\n", norm2(x), norm2(r)); + if constexpr ( debug ) { + for (auto i = 0u; i < b.size(); i++) { + printfQuda("Prolongated coarse solution y2 = %e\n", norm2(r[i])); + printfQuda("after coarse-grid correction x2 = %e, r2 = %e\n", norm2(x[i]), norm2(r[i])); + } } } - if (debug) printfQuda("preparing to post smooth\n"); - - // do the post smoothing - // residual = outer_solution_type == QUDA_MAT_SOLUTION ? 
r : r.Even(); // refine for outer solution type - if (param.smoother_solve_type == QUDA_DIRECT_PC_SOLVE) { - in = b_tilde.create_alias(); - } else { // this incurs unecessary copying - r = b; - in = r.create_alias(); - } - // we should keep a copy of the prepared right hand side as we've already destroyed it //dirac.prepare(in, out, solution, residual, inner_solution_type); if (postsmoother) (*postsmoother)(out, in); // for inner solve preconditioned, in the should be the original prepared rhs - if (debug) printfQuda("exited postsmooth, about to reconstruct\n"); - diracSmoother->reconstruct(x, b, outer_solution_type); - if (debug) printfQuda("finished reconstruct\n"); - } else { // do the coarse grid solve - ColorSpinorField out, in; + std::vector out(b.size()), in(b.size()); diracSmoother->prepare(out, in, x, b, outer_solution_type); if (presmoother) (*presmoother)(out, in); diracSmoother->reconstruct(x, b, outer_solution_type); } // FIXME on subset check - if (debug && b.SiteSubset() == r.SiteSubset()) { + if (debug && b.SiteSubset() == cvector_ref(r).SiteSubset()) { (*param.matResidual)(r, x); - double r2 = xmyNorm(b, r); - printfQuda("leaving V-cycle with x2=%e, r2=%e\n", norm2(x), r2); + auto r2 = xmyNorm(b, r); + for (auto i = 0u; i < r2.size(); i++) printfQuda("leaving V-cycle with x2=%e, r2=%e\n", norm2(x[i]), r2[i]); } popOutputPrefix(); @@ -1346,7 +1344,7 @@ namespace quda } solverParam.pipeline = (solverParam.inv_type == QUDA_BICGSTAB_INVERTER ? 
0 : 4); // FIXME: pipeline != 0 breaks BICGSTAB - solverParam.precision = r.Precision(); + solverParam.precision = r[0].Precision(); if (is_fine_grid()) { solverParam.precision_sloppy = param.mg_global.invert_param->cuda_prec_precondition; @@ -1358,7 +1356,7 @@ namespace quda solverParam.residual_type = static_cast(QUDA_L2_RELATIVE_RESIDUAL); solverParam.compute_null_vector = QUDA_COMPUTE_NULL_VECTOR_YES; ColorSpinorParam csParam(B[0]); // Create spinor field parameters: - csParam.setPrecision(r.Precision(), r.Precision(), true); // ensure native ordering + csParam.setPrecision(r[0].Precision(), r[0].Precision(), true); // ensure native ordering csParam.location = QUDA_CUDA_FIELD_LOCATION; // hard code to GPU location for null-space generation for now csParam.gammaBasis = B[0].Nspin() == 1 ? QUDA_DEGRAND_ROSSI_GAMMA_BASIS : QUDA_UKQCD_GAMMA_BASIS; // degrand-rossi required for staggered From eefe8c776e4e23c77b865e5e4fe2d08da3d9f45d Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Fri, 26 Jul 2024 05:42:51 -0700 Subject: [PATCH 037/103] Remove some legacy debug code from multigrid --- lib/multigrid.cpp | 64 ++++++++++++++--------------------------------- 1 file changed, 19 insertions(+), 45 deletions(-) diff --git a/lib/multigrid.cpp b/lib/multigrid.cpp index 6f7dc19a50..9840c420d4 100644 --- a/lib/multigrid.cpp +++ b/lib/multigrid.cpp @@ -135,14 +135,14 @@ namespace quda if (r_coarse.empty()) { r_coarse.resize(1); r_coarse[0] = param.B[0].create_coarse(param.geoBlockSize, param.spinBlockSize, param.Nvec, r[0].Precision(), - param.mg_global.location[param.level + 1]); + param.mg_global.location[param.level + 1]); } // create coarse solution vector if not already created in verify() if (x_coarse.empty()) { x_coarse.resize(1); x_coarse[0] = param.B[0].create_coarse(param.geoBlockSize, param.spinBlockSize, param.Nvec, r[0].Precision(), - param.mg_global.location[param.level + 1]); + param.mg_global.location[param.level + 1]); } int nVec_coarse = std::max(param.Nvec, 
param.mg_global.n_vec[param.level + 1]); @@ -765,7 +765,8 @@ namespace quda pushLevel(param.level); QudaPrecision prec = (param.mg_global.precision_null[param.level] < r[0].Precision()) ? - param.mg_global.precision_null[param.level] : r[0].Precision(); + param.mg_global.precision_null[param.level] : + r[0].Precision(); // may want to revisit this---these were relaxed for cases where ghost_precision < precision // these were set while hacking in tests of quarter precision ghosts @@ -831,7 +832,7 @@ namespace quda for (int i = 0; i < param.Nvec; i++) { transfer->R(r_coarse[0], param.B[i]); - (*coarse_solver)(x_coarse[0], r_coarse[0]); // this needs to be an exact solve to pass + (*coarse_solver)(x_coarse[0], r_coarse[0]); // this needs to be an exact solve to pass setOutputPrefix(prefix); // restore prefix after return from coarse grid transfer->P(tmp2, x_coarse[0]); (*param.matResidual)(tmp1, tmp2); @@ -865,7 +866,7 @@ namespace quda if (r_coarse.empty()) { r_coarse.resize(1); r_coarse[0] = param.B[0].create_coarse(param.geoBlockSize, param.spinBlockSize, param.Nvec, r[0].Precision(), - param.mg_global.location[param.level + 1]); + param.mg_global.location[param.level + 1]); } // create coarse solution vector if not already created in verify() @@ -884,8 +885,8 @@ namespace quda auto r2 = norm2(r_coarse[0]); auto max_deviation = blas::max_deviation(r_coarse[0], x_coarse[0]); auto l2_deviation = sqrt(xmyNorm(x_coarse[0], r_coarse[0]) / norm2(x_coarse[0])); - logQuda(QUDA_VERBOSE, "L2 norms %e %e (fine tmp %e); Deviations: L2 relative = %e, max = %e\n", norm2(x_coarse[0]), - r2, norm2(tmp2), l2_deviation, max_deviation[0]); + logQuda(QUDA_VERBOSE, "L2 norms %e %e (fine tmp %e); Deviations: L2 relative = %e, max = %e\n", + norm2(x_coarse[0]), r2, norm2(tmp2), l2_deviation, max_deviation[0]); if (check_deviation(l2_deviation, tol)) errorQuda("coarse span failed: L2 relative deviation = %e > %e", l2_deviation, tol); if (check_deviation(max_deviation[0], tol)) @@ -1117,7 
+1118,7 @@ namespace quda logQuda(QUDA_SUMMARIZE, "Checking 1 > || (1 - DP(P^dagDP)P^dag) v_k || / || v_k || for vector %d\n", i); transfer->R(r_coarse[0], param.B[i]); - (*coarse_solver)(x_coarse[0], r_coarse[0]); // this needs to be an exact solve to pass + (*coarse_solver)(x_coarse[0], r_coarse[0]); // this needs to be an exact solve to pass setOutputPrefix(prefix); // restore prefix after return from coarse grid transfer->P(tmp2, x_coarse[0]); (*param.matResidual)(tmp1, tmp2); @@ -1164,24 +1165,14 @@ namespace quda QudaSolutionType outer_solution_type = b.SiteSubset() == QUDA_FULL_SITE_SUBSET ? QUDA_MAT_SOLUTION : QUDA_MATPC_SOLUTION; QudaSolutionType inner_solution_type = param.coarse_grid_solution_type; - if constexpr (debug) printfQuda("outer_solution_type = %d, inner_solution_type = %d\n", outer_solution_type, inner_solution_type); - if ( outer_solution_type == QUDA_MATPC_SOLUTION && inner_solution_type == QUDA_MAT_SOLUTION) errorQuda("Unsupported solution type combination"); if ( inner_solution_type == QUDA_MATPC_SOLUTION && param.smoother_solve_type != QUDA_DIRECT_PC_SOLVE) errorQuda("For this coarse grid solution type, a preconditioned smoother is required"); - if constexpr ( debug ) for (auto i = 0u; i < b.size(); i++) - printfQuda("entering V-cycle with x2=%e, r2=%e\n", norm2(x[i]), norm2(b[i])); - if (param.level < param.Nlevel-1) { - //transfer->setTransferGPU(false); // use this to force location of transfer (need to check if still works for multi-level) - // do the pre smoothing - if constexpr (debug) for (auto i = 0u; i < b.size(); i++) - printfQuda("pre-smoothing b2=%e site subset %d\n", norm2(b[i]), b.SiteSubset()); - std::vector out(b.size()), in(b.size()); diracSmoother->prepare(out, in, x, b, outer_solution_type); @@ -1203,17 +1194,17 @@ namespace quda false; // FIXME this is currently borked if inner solver is preconditioned - const auto &residual = !presmoother ? b : - use_solver_residual ? 
presmoother->get_residual() : - b.SiteSubset() == QUDA_FULL_SITE_SUBSET ? cvector_ref(r) : - cvector_ref(r).Even(); + const auto &residual = !presmoother ? b : + use_solver_residual ? presmoother->get_residual() : + b.SiteSubset() == QUDA_FULL_SITE_SUBSET ? cvector_ref(r) : + cvector_ref(r).Even(); if (!use_solver_residual && presmoother) { - auto &residual = b.SiteSubset() == QUDA_FULL_SITE_SUBSET ? cvector_ref(r) : cvector_ref(r).Even(); + auto &residual = b.SiteSubset() == QUDA_FULL_SITE_SUBSET ? cvector_ref(r) : + cvector_ref(r).Even(); (*param.matResidual)(residual, x); axpby(1.0, b, -1.0, residual); } - auto r2 = debug ? norm2(residual) : 0.0; // We need this to ensure that the coarse level has been created. // e.g. in case of iterative setup with MG we use just pre- and post-smoothing at the first iteration. @@ -1221,27 +1212,16 @@ namespace quda // restrict to the coarse grid transfer->R(r_coarse, residual); - if constexpr (debug) - for (auto i = 0u; i < b.size(); i++) - printfQuda("after pre-smoothing x2 = %e, r2 = %e, r_coarse2 = %e\n", norm2(x[i]), r2[i], norm2(r_coarse[i])); // recurse to the next lower level (*coarse_solver)(x_coarse, r_coarse); - if constexpr (debug) - for (auto i = 0u; i < b.size(); i++) - printfQuda("after coarse solve x_coarse2 = %e r_coarse2 = %e\n", norm2(x_coarse[i]), norm2(r_coarse[i])); // prolongate back to this grid - auto &x_coarse_2_fine - = inner_solution_type == QUDA_MAT_SOLUTION ? cvector_ref(r) : cvector_ref(r).Even(); // define according to inner solution type + auto &x_coarse_2_fine = inner_solution_type == QUDA_MAT_SOLUTION ? 
+ cvector_ref(r) : + cvector_ref(r).Even(); // define according to inner solution type transfer->P(x_coarse_2_fine, x_coarse); // repurpose residual storage xpy(x_coarse_2_fine, solution); // sum to solution FIXME - sum should be done inside the transfer operator - if constexpr ( debug ) { - for (auto i = 0u; i < b.size(); i++) { - printfQuda("Prolongated coarse solution y2 = %e\n", norm2(r[i])); - printfQuda("after coarse-grid correction x2 = %e, r2 = %e\n", norm2(x[i]), norm2(r[i])); - } - } } // we should keep a copy of the prepared right hand side as we've already destroyed it @@ -1258,13 +1238,7 @@ namespace quda diracSmoother->prepare(out, in, x, b, outer_solution_type); if (presmoother) (*presmoother)(out, in); diracSmoother->reconstruct(x, b, outer_solution_type); - } - // FIXME on subset check - if (debug && b.SiteSubset() == cvector_ref(r).SiteSubset()) { - (*param.matResidual)(r, x); - auto r2 = xmyNorm(b, r); - for (auto i = 0u; i < r2.size(); i++) printfQuda("leaving V-cycle with x2=%e, r2=%e\n", norm2(x[i]), r2[i]); } popOutputPrefix(); @@ -1356,7 +1330,7 @@ namespace quda solverParam.residual_type = static_cast(QUDA_L2_RELATIVE_RESIDUAL); solverParam.compute_null_vector = QUDA_COMPUTE_NULL_VECTOR_YES; ColorSpinorParam csParam(B[0]); // Create spinor field parameters: - csParam.setPrecision(r[0].Precision(), r[0].Precision(), true); // ensure native ordering + csParam.setPrecision(r[0].Precision(), r[0].Precision(), true); // ensure native ordering csParam.location = QUDA_CUDA_FIELD_LOCATION; // hard code to GPU location for null-space generation for now csParam.gammaBasis = B[0].Nspin() == 1 ? QUDA_DEGRAND_ROSSI_GAMMA_BASIS : QUDA_UKQCD_GAMMA_BASIS; // degrand-rossi required for staggered From f58bc3b3ac706eb4054f832576c9f271dacb8c3a Mon Sep 17 00:00:00 2001 From: Jiqun Tu Date: Sat, 27 Jul 2024 14:10:10 +0000 Subject: [PATCH 038/103] Add rescaling to coarse dslash with MMA - the code still needs cleanup. 
--- include/kernels/dslash_coarse_mma.cuh | 57 ++--- include/targets/cuda/mma_tensor_op/gemm.cuh | 199 ++++++++++++++++++ .../cuda/mma_tensor_op/hmma_m16n16k4_sm70.cuh | 5 + 3 files changed, 237 insertions(+), 24 deletions(-) diff --git a/include/kernels/dslash_coarse_mma.cuh b/include/kernels/dslash_coarse_mma.cuh index c246fcae13..60cdef5343 100644 --- a/include/kernels/dslash_coarse_mma.cuh +++ b/include/kernels/dslash_coarse_mma.cuh @@ -198,7 +198,8 @@ namespace quda } }; - auto dslash_forward_consumer = [&](int d, float scale_inv_a, float scale_inv_b) { + auto dslash_forward_consumer = [&](int d, float scale_inv_a, float scale_inv_b) -> float { + float rescale_factor; if (forward_exterior[d]) { if constexpr (doHalo()) { constexpr bool a_dagger = false; @@ -213,11 +214,12 @@ namespace quda pipe.consumer_wait(); __syncthreads(); - a_loader.template tmp2s(smem_tmp_a, scale_inv_a, smem_obj_a_real, smem_obj_a_imag); - b_loader.template tmp2s(smem_tmp_b_ghost, scale_inv_b, smem_obj_b_real, + float rescale_factor_a = a_loader.template tmp2s_rescale(smem_tmp_a, scale_inv_a, smem_obj_a_real, smem_obj_a_imag); + float rescale_factor_b = b_loader.template tmp2s_rescale(smem_tmp_b_ghost, scale_inv_b, smem_obj_b_real, smem_obj_b_imag); pipe.consumer_release(); __syncthreads(); + rescale_factor = rescale_factor_a * rescale_factor_b; } } else if constexpr (doBulk()) { @@ -231,16 +233,18 @@ namespace quda pipe.consumer_wait(); __syncthreads(); - a_loader.template tmp2s(smem_tmp_a, scale_inv_a, smem_obj_a_real, smem_obj_a_imag); - b_loader.template tmp2s(smem_tmp_b, scale_inv_b, smem_obj_b_real, smem_obj_b_imag); + float rescale_factor_a = a_loader.template tmp2s_rescale(smem_tmp_a, scale_inv_a, smem_obj_a_real, smem_obj_a_imag); + float rescale_factor_b = b_loader.template tmp2s_rescale(smem_tmp_b, scale_inv_b, smem_obj_b_real, smem_obj_b_imag); pipe.consumer_release(); __syncthreads(); + rescale_factor = rescale_factor_a * rescale_factor_b; } + return rescale_factor; }; - 
auto dslash_forward_compute = [&](int d) { + auto dslash_forward_compute = [&](int d, float rescale_factor) { if (forward_exterior[d] && doHalo() || doBulk()) { - accumulator.mma(smem_obj_a_real, smem_obj_a_imag, smem_obj_b_real, smem_obj_b_imag); + accumulator.mma_rescale(smem_obj_a_real, smem_obj_a_imag, smem_obj_b_real, smem_obj_b_imag, rescale_factor); } }; @@ -281,7 +285,8 @@ namespace quda } }; - auto dslash_backward_consumer = [&](int d, float scale_inv_a, float scale_inv_b) { + auto dslash_backward_consumer = [&](int d, float scale_inv_a, float scale_inv_b) -> float { + float rescale_factor; if (backward_exterior[d]) { if constexpr (doHalo()) { constexpr bool a_dagger = true; @@ -296,11 +301,12 @@ namespace quda pipe.consumer_wait(); __syncthreads(); - a_loader.template tmp2s(smem_tmp_a, scale_inv_a, smem_obj_a_real, smem_obj_a_imag); - b_loader.template tmp2s(smem_tmp_b_ghost, scale_inv_b, smem_obj_b_real, + float rescale_factor_a = a_loader.template tmp2s_rescale(smem_tmp_a, scale_inv_a, smem_obj_a_real, smem_obj_a_imag); + float rescale_factor_b = b_loader.template tmp2s_rescale(smem_tmp_b_ghost, scale_inv_b, smem_obj_b_real, smem_obj_b_imag); pipe.consumer_release(); __syncthreads(); + rescale_factor = rescale_factor_a * rescale_factor_b; } } else if constexpr (doBulk()) { constexpr bool a_dagger = true; @@ -313,16 +319,18 @@ namespace quda pipe.consumer_wait(); __syncthreads(); - a_loader.template tmp2s(smem_tmp_a, scale_inv_a, smem_obj_a_real, smem_obj_a_imag); - b_loader.template tmp2s(smem_tmp_b, scale_inv_b, smem_obj_b_real, smem_obj_b_imag); + float rescale_factor_a = a_loader.template tmp2s_rescale(smem_tmp_a, scale_inv_a, smem_obj_a_real, smem_obj_a_imag); + float rescale_factor_b = b_loader.template tmp2s_rescale(smem_tmp_b, scale_inv_b, smem_obj_b_real, smem_obj_b_imag); pipe.consumer_release(); __syncthreads(); + rescale_factor = rescale_factor_a * rescale_factor_b; } + return rescale_factor; }; - auto dslash_backward_compute = [&](int d) { + 
auto dslash_backward_compute = [&](int d, float rescale_factor) { if (backward_exterior[d] && doHalo() || doBulk()) { - accumulator.mma(smem_obj_a_real, smem_obj_a_imag, smem_obj_b_real, smem_obj_b_imag); + accumulator.mma_rescale(smem_obj_a_real, smem_obj_a_imag, smem_obj_b_real, smem_obj_b_imag, rescale_factor); } }; @@ -341,7 +349,7 @@ namespace quda pipe.producer_commit(); }; - auto clover_consumer = [&](float scale_inv_a, float scale_inv_b) { + auto clover_consumer = [&](float scale_inv_a, float scale_inv_b) -> float { constexpr bool a_dagger = Arg::dagger; constexpr bool b_dagger = false; @@ -352,13 +360,14 @@ namespace quda pipe.consumer_wait(); __syncthreads(); - a_loader.template tmp2s(smem_tmp_a, scale_inv_a, smem_obj_a_real, smem_obj_a_imag); - b_loader.template tmp2s(smem_tmp_b, scale_inv_b, smem_obj_b_real, smem_obj_b_imag); + float rescale_factor_a = a_loader.template tmp2s_rescale(smem_tmp_a, scale_inv_a, smem_obj_a_real, smem_obj_a_imag); + float rescale_factor_b = b_loader.template tmp2s_rescale(smem_tmp_b, scale_inv_b, smem_obj_b_real, smem_obj_b_imag); pipe.consumer_release(); __syncthreads(); + return rescale_factor_a * rescale_factor_b; }; - auto clover_compute = [&]() { accumulator.mma(smem_obj_a_real, smem_obj_a_imag, smem_obj_b_real, smem_obj_b_imag); }; + auto clover_compute = [&](float rescale_factor) { accumulator.mma_rescale(smem_obj_a_real, smem_obj_a_imag, smem_obj_b_real, smem_obj_b_imag, rescale_factor); }; float scale_inv_a; float scale_inv_b; @@ -373,19 +382,19 @@ namespace quda #pragma unroll for (int d = 0; d < Arg::nDim; d++) // loop over dimension { - dslash_forward_consumer(d, scale_inv_a, scale_inv_b); + float rescale_factor = dslash_forward_consumer(d, scale_inv_a, scale_inv_b); if (d < 3) { dslash_forward_producer(d + 1, scale_inv_a, scale_inv_b, k_offset); } else { dslash_backward_producer(0, scale_inv_a, scale_inv_b, k_offset); } - dslash_forward_compute(d); + dslash_forward_compute(d, rescale_factor); } // nDim // 
Backward gather - compute back offset for spinor and gauge fetch #pragma unroll for (int d = 0; d < Arg::nDim; d++) { - dslash_backward_consumer(d, scale_inv_a, scale_inv_b); + float rescale_factor = dslash_backward_consumer(d, scale_inv_a, scale_inv_b); if (d < 3) { dslash_backward_producer(d + 1, scale_inv_a, scale_inv_b, k_offset); } else if (k_offset + Arg::bK < K) { @@ -393,7 +402,7 @@ namespace quda } else if constexpr (doBulk() && Arg::clover) { clover_producer(scale_inv_a, scale_inv_b, 0); } - dslash_backward_compute(d); + dslash_backward_compute(d, rescale_factor); } // nDim } @@ -407,9 +416,9 @@ namespace quda if constexpr (doBulk() && Arg::clover) { if constexpr (!Arg::dslash) { clover_producer(scale_inv_a, scale_inv_b, 0); } for (int k_offset = 0; k_offset < K; k_offset += Arg::bK) { - clover_consumer(scale_inv_a, scale_inv_b); + float rescale_factor = clover_consumer(scale_inv_a, scale_inv_b); if (k_offset + Arg::bK < K) { clover_producer(scale_inv_a, scale_inv_b, k_offset + Arg::bK); } - clover_compute(); + clover_compute(rescale_factor); } } diff --git a/include/targets/cuda/mma_tensor_op/gemm.cuh b/include/targets/cuda/mma_tensor_op/gemm.cuh index 9f8eef8241..2185528cef 100644 --- a/include/targets/cuda/mma_tensor_op/gemm.cuh +++ b/include/targets/cuda/mma_tensor_op/gemm.cuh @@ -6,6 +6,8 @@ #include #include +#include + namespace quda { namespace mma @@ -45,6 +47,10 @@ namespace quda reg_imag = 0; } + inline __device__ float abs_max(float a, float max) { + return fmaxf(fabsf(a), max); + } + /** @brief Load from global memory and store data in registers. */ @@ -80,6 +86,80 @@ namespace quda } } + /** + @brief Load from global memory and store data in registers. 
+ */ + template + inline __device__ void convert_x_rescale(float ®_real, float ®_imag, complex *p, int m_idx, int n_idx, + float scale_inv, float rescale) + { + if (x) { + auto xx = p[m_idx * ld + n_idx]; + + if (fixed) { + reg_real = scale_inv * xx.real() * rescale; + auto scale_inv_conj = dagger ? -scale_inv : scale_inv; + reg_imag = scale_inv_conj * xx.imag() * rescale; + } else { + reg_real = +xx.real() * rescale; + reg_imag = (dagger ? -xx.imag() : +xx.imag()) * rescale; + } + } else { + auto xx = p[n_idx * ld + m_idx]; + using store_type = T; + using store_array = typename VectorType::type; + store_array v = *reinterpret_cast(&p[n_idx * ld + m_idx]); + + if (fixed) { + reg_real = scale_inv * xx.real() * rescale; + auto scale_inv_conj = dagger ? -scale_inv : scale_inv; + reg_imag = scale_inv_conj * xx.imag() * rescale; + } else { + reg_real = xx.real() * rescale; + reg_imag = (dagger ? -xx.imag() : xx.imag()) * rescale; + } + } + } + + /** + @brief Load from global memory and store data in registers. + */ + template + inline __device__ float find_abs_max(complex *p, int m_idx, int n_idx, + float scale_inv) + { + float this_max = 0.0f; + + if (x) { + auto xx = p[m_idx * ld + n_idx]; + + if (fixed) { + this_max = abs_max(scale_inv * xx.real(), this_max); + auto scale_inv_conj = dagger ? -scale_inv : scale_inv; + this_max = abs_max(scale_inv_conj * xx.imag(), this_max); + } else { + this_max = abs_max(+xx.real(), this_max); + this_max = abs_max(dagger ? -xx.imag() : +xx.imag(), this_max); + } + } else { + auto xx = p[n_idx * ld + m_idx]; + using store_type = T; + using store_array = typename VectorType::type; + store_array v = *reinterpret_cast(&p[n_idx * ld + m_idx]); + + if (fixed) { + this_max = abs_max(scale_inv * xx.real(), this_max); + auto scale_inv_conj = dagger ? -scale_inv : scale_inv; + this_max = abs_max(scale_inv_conj * xx.imag(), this_max); + } else { + this_max = abs_max(xx.real(), this_max); + this_max = abs_max(dagger ? 
-xx.imag() : xx.imag(), this_max); + } + } + + return this_max; + } + /** @brief Load from global memory and store data in registers. */ @@ -196,6 +276,94 @@ namespace quda return gmem.get_scale_inv(); } + template + __device__ inline float tmp2s_rescale(complex *smem_ptr, float scale_inv, smem_accessor_t &smem_real, + smem_accessor_t &smem_imag) + { + // for each iteration, each warp loads a tile + int thread_id = (threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x; + int warp_id = thread_id / 32; + int lane_id = thread_id % 32; + int thread_in_group = lane_id % 4; + int group_id = lane_id / 4; + constexpr int w_m = 8 * batch; + constexpr int w_k = 4; + static_assert(bM % w_m == 0, "bM %% w_m"); + static_assert(bN % w_k == 0, "bN %% w_k"); + + constexpr int tile_dim_m = bM / w_m; + constexpr int tile_dim_k = bN / w_k; + + constexpr int total_tiles = tile_dim_k * tile_dim_m; + constexpr int n_warp = block_y * block_z / 32; + constexpr int warp_cycle = (total_tiles + n_warp - 1) / n_warp; + + float thread_max = 0.0f; + +#pragma unroll + for (int c = 0; c < warp_cycle; c++) { + int logical_warp_index = c * n_warp + warp_id; + if (logical_warp_index < total_tiles) { + int warp_m = (c * n_warp + warp_id) % tile_dim_m; + int warp_k = (c * n_warp + warp_id) / tile_dim_m; + + int smem_m_offset = warp_m * w_m + group_id * batch; + int smem_k_offset = warp_k * w_k + thread_in_group; + + int gmem_m_offset = smem_m_offset; + int gmem_k_offset = smem_k_offset; + + constexpr bool x = (transpose == dagger); + float this_max = find_abs_max(smem_ptr, gmem_m_offset, gmem_k_offset, + scale_inv); + thread_max = fmaxf(this_max, thread_max); + } + } + + __syncthreads(); + // block all-reduce thread_max + using block_reduce_t = cub::BlockReduce; + __shared__ typename block_reduce_t::TempStorage temp_storage; + float block_max = block_reduce_t(temp_storage).Reduce(thread_max, cub::Max()); + + __shared__ float block_max_all; + if (threadIdx.x + blockDim.x * (threadIdx.y + 
blockDim.y * threadIdx.z) == 0) { + if (block_max > 0.0f) { + block_max_all = block_max; + } else { + block_max_all = 1.0f; + } + } + __syncthreads(); + float block_rescale_factor = 1e4f / block_max_all; + +#pragma unroll + for (int c = 0; c < warp_cycle; c++) { + int logical_warp_index = c * n_warp + warp_id; + if (logical_warp_index < total_tiles) { + int warp_m = (c * n_warp + warp_id) % tile_dim_m; + int warp_k = (c * n_warp + warp_id) / tile_dim_m; + + int smem_m_offset = warp_m * w_m + group_id * batch; + int smem_k_offset = warp_k * w_k + thread_in_group; + + int gmem_m_offset = smem_m_offset; + int gmem_k_offset = smem_k_offset; + + load_t real; + load_t imag; + + constexpr bool x = (transpose == dagger); + convert_x_rescale(real, imag, smem_ptr, gmem_m_offset, gmem_k_offset, + scale_inv, block_rescale_factor); + smem_real.vector_load(smem_m_offset, smem_k_offset, real); + smem_imag.vector_load(smem_m_offset, smem_k_offset, imag); + } + } + + return 1.0f / block_rescale_factor; + } + template __device__ inline void tmp2s(complex *smem_ptr, float scale_inv, smem_accessor_t &smem_real, smem_accessor_t &smem_imag) @@ -374,6 +542,37 @@ namespace quda } } + template + __device__ inline void mma_rescale(const SmemObjA &smem_obj_a_real, const SmemObjA &smem_obj_a_imag, + const SmemObjB &smem_obj_b_real, const SmemObjB &smem_obj_b_imag, + float rescale) + { + +#pragma unroll + for (int c = 0; c < warp_cycle; c++) { + typename mma_t::OperandC op_c_real_tmp; + op_c_real_tmp.zero(); + typename mma_t::OperandC op_c_imag_tmp; + op_c_imag_tmp.zero(); + +#pragma unroll 1 + for (int tile_k = 0; tile_k < tile_acc_dim; tile_k++) { + + // The logical warp assigned to each part of the matrix. 
+ const int logical_warp_index = wrm.warp_id * warp_cycle + c; + if (logical_warp_index < tile_row_dim * tile_col_dim) { + const int warp_row = logical_warp_index / tile_col_dim; + const int warp_col = logical_warp_index - warp_row * tile_col_dim; + + complex_mma(smem_obj_a_real, smem_obj_a_imag, smem_obj_b_real, smem_obj_b_imag, op_c_real_tmp, + op_c_imag_tmp, warp_row, warp_col, tile_k, wrm); + } + } + op_c_real[c].axpy(rescale, op_c_real_tmp); + op_c_imag[c].axpy(rescale, op_c_imag_tmp); + } + } + template __device__ inline void mma(const SmemObjA &smem_obj_a_real, const SmemObjA &smem_obj_a_imag, const SmemObjB &smem_obj_b_real, const SmemObjB &smem_obj_b_imag) diff --git a/include/targets/cuda/mma_tensor_op/hmma_m16n16k4_sm70.cuh b/include/targets/cuda/mma_tensor_op/hmma_m16n16k4_sm70.cuh index b7851dc58f..04ec30c8f8 100644 --- a/include/targets/cuda/mma_tensor_op/hmma_m16n16k4_sm70.cuh +++ b/include/targets/cuda/mma_tensor_op/hmma_m16n16k4_sm70.cuh @@ -150,6 +150,11 @@ namespace quda for (int i = 0; i < 8; i++) { reg[i] *= alpha; } } + __device__ inline void axpy(float alpha, OperandC x) { +#pragma unroll + for (int i = 0; i < 8; i++) { reg[i] += alpha * x.reg[i]; } + } + template __device__ inline void store(void *smem, int warp_row, int warp_col, const WarpRegisterMapping &wrm) { From 958bf12d03c293d56ea99978d21418bc6d8ddb25 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Sun, 28 Jul 2024 00:40:49 -0700 Subject: [PATCH 039/103] Add tensor core support for 32/64 MG coarsening --- lib/coarse_op_mma_launch.h | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/lib/coarse_op_mma_launch.h b/lib/coarse_op_mma_launch.h index e7e89b4ccb..d7453ee704 100644 --- a/lib/coarse_op_mma_launch.h +++ b/lib/coarse_op_mma_launch.h @@ -250,6 +250,24 @@ namespace quda return -1; } + template + std::enable_if_t + launch_compute_uv_kernel(TuneParam &tp, const Arg &arg, int min_threads, const qudaStream_t &stream, Tunable &tunable) + { + if 
(query_max) return 2; + switch (tp.aux.x) { + // clang-format off + case 0: launch_compute_uv_kernel< 64, 64, 32, 8, 16>(tp, arg, min_threads, stream, tunable); break; + case 1: launch_compute_uv_kernel< 64, 64, 32, 8, 32>(tp, arg, min_threads, stream, tunable); break; + case 2: launch_compute_uv_kernel< 64, 64, 32, 16, 16>(tp, arg, min_threads, stream, tunable); break; + // clang-format on + default: + errorQuda("tp.aux.x(=%d) is NOT supported by (%d, %d, %d, %d).", tp.aux.x, Arg::fineSpin, Arg::coarseSpin, + Arg::fineColor, Arg::coarseColor); + } + return -1; + } + // note --- currently unused, may be revisited in the future template std::enable_if_t @@ -326,6 +344,7 @@ namespace quda || (Arg::fineColor == 24 && Arg::coarseColor == 64 && Arg::fineSpin == 2 && Arg::coarseSpin == 2) || (Arg::fineColor == 24 && Arg::coarseColor == 96 && Arg::fineSpin == 2 && Arg::coarseSpin == 2) || (Arg::fineColor == 32 && Arg::coarseColor == 32 && Arg::fineSpin == 2 && Arg::coarseSpin == 2) + || (Arg::fineColor == 32 && Arg::coarseColor == 64 && Arg::fineSpin == 2 && Arg::coarseSpin == 2) || (Arg::fineColor == 64 && Arg::coarseColor == 64 && Arg::fineSpin == 2 && Arg::coarseSpin == 2) || (Arg::fineColor == 64 && Arg::coarseColor == 96 && Arg::fineSpin == 2 && Arg::coarseSpin == 2) || (Arg::fineColor == 96 && Arg::coarseColor == 96 && Arg::fineSpin == 2 && Arg::coarseSpin == 2)), @@ -554,6 +573,23 @@ namespace quda return -1; } + template + std::enable_if_t + launch_compute_vuv_kernel(TuneParam &tp, const Arg &arg, int min_threads, const qudaStream_t &stream, Tunable &tunable) + { + if (query_max) return 3; + // clang-format off + switch (tp.aux.x) { + case 0: launch_compute_vuv_kernel< 64, 64, 32, 8, 8>(tp, arg, min_threads, stream, tunable); break; + case 1: launch_compute_vuv_kernel< 64, 64, 32, 8, 16>(tp, arg, min_threads, stream, tunable); break; + case 2: launch_compute_vuv_kernel< 64, 64, 32, 16, 8>(tp, arg, min_threads, stream, tunable); break; + case 3: 
launch_compute_vuv_kernel< 64, 64, 32, 32, 4>(tp, arg, min_threads, stream, tunable); break; + default: errorQuda("tp.aux.x(=%d) is NOT supported by (%d, %d, %d, %d).", tp.aux.x, Arg::fineSpin, Arg::coarseSpin, Arg::fineColor, Arg::coarseColor); + } + // clang-format on + return -1; + } + // note --- currently unused, may be revisited in the future template std::enable_if_t @@ -631,6 +667,7 @@ namespace quda || (Arg::fineColor == 24 && Arg::coarseColor == 64 && Arg::fineSpin == 2 && Arg::coarseSpin == 2) || (Arg::fineColor == 24 && Arg::coarseColor == 96 && Arg::fineSpin == 2 && Arg::coarseSpin == 2) || (Arg::fineColor == 32 && Arg::coarseColor == 32 && Arg::fineSpin == 2 && Arg::coarseSpin == 2) + || (Arg::fineColor == 32 && Arg::coarseColor == 64 && Arg::fineSpin == 2 && Arg::coarseSpin == 2) || (Arg::fineColor == 64 && Arg::coarseColor == 64 && Arg::fineSpin == 2 && Arg::coarseSpin == 2) || (Arg::fineColor == 64 && Arg::coarseColor == 96 && Arg::fineSpin == 2 && Arg::coarseSpin == 2) || (Arg::fineColor == 96 && Arg::coarseColor == 96 && Arg::fineSpin == 2 && Arg::coarseSpin == 2)), From 21b4c08faa7b06ec717819bb0d1699e41b10b3ae Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Sun, 28 Jul 2024 00:42:01 -0700 Subject: [PATCH 040/103] Add striped signifier to packing kernel tune key --- lib/dslash_pack2.cu | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/lib/dslash_pack2.cu b/lib/dslash_pack2.cu index 1db387df98..46580ad6f1 100644 --- a/lib/dslash_pack2.cu +++ b/lib/dslash_pack2.cu @@ -53,8 +53,6 @@ namespace quda template class Pack : TunableKernel3D { - -protected: void **ghost; const ColorSpinorField &halo; cvector_ref ∈ @@ -166,8 +164,11 @@ protected: case Device: strcat(aux, ",device-device"); break; case Host: strcat(aux, comm_peer2peer_enabled_global() ? 
",host-device" : ",host-host"); break; case Shmem: strcat(aux, ",shmem"); break; - default: errorQuda("Unknown pack target location %d\n", location); + default: errorQuda("Unknown pack target location %d", location); } +#ifdef STRIPED + strcat(aux, ",striped"); +#endif } public: @@ -340,7 +341,7 @@ public: #endif } else { - errorQuda("Unsupported nSpin = %d\n", in.Nspin()); + errorQuda("Unsupported nSpin = %d", in.Nspin()); } } From fce329d413f36c16e1d98724a1f90e47e1572029 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Sun, 28 Jul 2024 16:34:22 -0700 Subject: [PATCH 041/103] Fix multi-RHS deflation --- lib/eigensolve_quda.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/lib/eigensolve_quda.cpp b/lib/eigensolve_quda.cpp index 47e2b1b60e..6efbeaaea9 100644 --- a/lib/eigensolve_quda.cpp +++ b/lib/eigensolve_quda.cpp @@ -624,10 +624,11 @@ namespace quda blas::block::cDotProduct(s, {evecs.begin(), evecs.begin() + n_defl}, {src.begin(), src.end()}); // 2. Perform block caxpy: V_i * (L_i)^{-1} * A_i - for (int i = 0; i < n_defl; i++) { s[i] /= evals[i].real(); } + for (auto j = 0u; j < src.size(); j++) + for (int i = 0; i < n_defl; i++) { s[i * src.size() + j] /= evals[i].real(); } // 3. 
Accumulate sum vec_defl = Sum_i V_i * (L_i)^{-1} * A_i - if (!accumulate) for (auto &x : sol) blas::zero(x); + if (!accumulate) blas::zero(sol); blas::block::caxpy(s, {evecs.begin(), evecs.begin() + n_defl}, {sol.begin(), sol.end()}); } From d9efb9c5416419efb1eb6d4f91ced06ebbf87d77 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Mon, 29 Jul 2024 04:12:29 -0700 Subject: [PATCH 042/103] Augmentation of state reporting to report the power, energy, temperature and clock rate in the QudaInvertParam struct to allow for interface level state reporting --- include/monitor.h | 21 +++++++++++++++ include/quda.h | 20 ++++++-------- include/timer.h | 12 ++++++--- lib/check_params.h | 8 ------ lib/interface_quda.cpp | 42 +++++++++++++---------------- lib/monitor.cpp | 45 ++++++++++++++++++++++++++++++- lib/quda_fortran.F90 | 4 +++ lib/timer.cpp | 61 +++++++++++++++++++++++++++++++++++++++--- 8 files changed, 163 insertions(+), 50 deletions(-) diff --git a/include/monitor.h b/include/monitor.h index 24ac11ee95..1b06d2ffbd 100644 --- a/include/monitor.h +++ b/include/monitor.h @@ -1,3 +1,5 @@ +#include "device.h" + namespace quda { @@ -21,6 +23,25 @@ namespace quda */ void serialize(); + /** + @brief Get the current size of the monitor state. Used for + bookending a period for later analysis. + */ + size_t size(); + + struct state_t { + double energy = 0.0; + double power = 0.0; + double temp = 0.0; + double clock = 0.0; + }; + + /** + @brief Get the mean state observables between start and end, where + start and end are two intervals of history in the state. 
+ */ + state_t mean(size_t start, size_t end); + } // namespace monitor } // namespace quda diff --git a/include/quda.h b/include/quda.h index 2737e192a6..7511271f18 100644 --- a/include/quda.h +++ b/include/quda.h @@ -278,6 +278,10 @@ extern "C" { int iter; /**< The number of iterations performed by the solver */ double gflops; /**< The Gflops rate of the solver */ double secs; /**< The time taken by the solver */ + double energy; /**< The energy consumed by the solver */ + double power; /**< The mean power of the solver */ + double temp; /**< The mean temperature of the device for the duration of the solve */ + double clock; /**< The mean clock frequency of the device for the duration of the solve */ QudaTune tune; /**< Enable auto-tuning? (default = QUDA_TUNE_YES) */ @@ -602,12 +606,6 @@ extern "C" { /** Whether to save eigenvectors in QIO singlefile or partfile format */ QudaBoolean partfile; - /** The Gflops rate of the eigensolver setup */ - double gflops; - - /**< The time taken by the eigensolver setup */ - double secs; - /** Which external library to use in the deflation operations (Eigen) */ QudaExtLibType extlib_type; //------------------------------------------------- @@ -808,12 +806,6 @@ extern "C" { /** Whether to preserve the deflation space during MG update */ QudaBoolean preserve_deflation; - /** The Gflops rate of the multigrid solver setup */ - double gflops; - - /**< The time taken by the multigrid solver setup */ - double secs; - /** Multiplicative factor for the mu parameter */ double mu_factor[QUDA_MAX_MG_LEVEL]; @@ -1822,6 +1814,10 @@ extern "C" { double secs; /** Flops count for the smearing operations **/ double gflops; + double energy; /**< The energy consumed by the smearing operations */ + double power; /**< The mean power of the smearing operations */ + double temp; /**< The mean temperature of the device for the duration of the smearing operations */ + double clock; /**< The mean clock frequency of the device for the duration of the 
smearing operations */ } QudaQuarkSmearParam; diff --git a/include/timer.h b/include/timer.h index a98928ec39..9cbb58792a 100644 --- a/include/timer.h +++ b/include/timer.h @@ -241,14 +241,20 @@ namespace quda { the profile stack, and be popped when its destructor is called. */ struct pushProfile { - static inline double secs_dummy = 0; - static inline double gflops_dummy = 0; TimeProfile &profile; double &secs; double &gflops; + double &energy; + double &power; + double &temp; + double &clock; uint64_t flops; bool active = false; - pushProfile(TimeProfile &profile, double &secs = secs_dummy, double &gflops = gflops_dummy); + size_t monitor_start; + size_t monitor_end; + + pushProfile(TimeProfile &profile, QudaInvertParam *param = nullptr); + pushProfile(TimeProfile &profile, QudaQuarkSmearParam *param); virtual ~pushProfile(); }; diff --git a/lib/check_params.h b/lib/check_params.h index 8a9f93c81b..13cfcb7593 100644 --- a/lib/check_params.h +++ b/lib/check_params.h @@ -1030,14 +1030,6 @@ void printQudaMultigridParam(QudaMultigridParam *param) { #endif } -#ifdef INIT_PARAM - P(gflops, 0.0); - P(secs, 0.0); -#elif defined(PRINT_PARAM) - P(gflops, INVALID_DOUBLE); - P(secs, INVALID_DOUBLE); -#endif - #ifdef INIT_PARAM P(allow_truncation, QUDA_BOOLEAN_FALSE); #else diff --git a/lib/interface_quda.cpp b/lib/interface_quda.cpp index 9de4b4aa50..d178690d2e 100644 --- a/lib/interface_quda.cpp +++ b/lib/interface_quda.cpp @@ -1384,11 +1384,11 @@ void endQuda(void) assertAllMemFree(); device::destroy(); - - comm_finalize(); - comms_initialized = false; } + comm_finalize(); + comms_initialized = false; + profileInit2End.TPSTOP(QUDA_PROFILE_TOTAL); // print out the profile information of the lifetime of the library @@ -1719,8 +1719,7 @@ namespace quda { void dslashQuda(void *h_out, void *h_in, QudaInvertParam *inv_param, QudaParity parity) { - auto profile = pushProfile(profileDslash, inv_param->secs, inv_param->gflops); - + auto profile = pushProfile(profileDslash, 
inv_param); const auto &gauge = (inv_param->dslash_type != QUDA_ASQTAD_DSLASH) ? *gaugePrecise : *gaugeFatPrecise; if ((!gaugePrecise && inv_param->dslash_type != QUDA_ASQTAD_DSLASH) @@ -1796,8 +1795,7 @@ void dslashQuda(void *h_out, void *h_in, QudaInvertParam *inv_param, QudaParity void shiftQuda(void *h_out, void *h_in, int dir, int sym, QudaInvertParam *param) { - auto profile = pushProfile(profileCovDev, param->secs, param->gflops); - + auto profile = pushProfile(profileCovDev, param); const auto &gauge = *gaugePrecise; //(inv_param->dslash_type != QUDA_ASQTAD_DSLASH) ? *gaugePrecise : *gaugeFatPrecise; QudaInvertParam &inv_param = *param; @@ -1866,8 +1864,7 @@ void shiftQuda(void *h_out, void *h_in, int dir, int sym, QudaInvertParam *param void spinTasteQuda(void *h_out, void *h_in, int spin_, int taste, QudaInvertParam *param) { - auto profile = pushProfile(profileCovDev, param->secs, param->gflops); - + auto profile = pushProfile(profileCovDev, param); const auto &gauge = *gaugePrecise; //(inv_param->dslash_type != QUDA_ASQTAD_DSLASH) ? *gaugePrecise : *gaugeFatPrecise; QudaInvertParam &inv_param = *param; @@ -2177,10 +2174,9 @@ void spinTasteQuda(void *h_out, void *h_in, int spin_, int taste, QudaInvertPara void covDevQuda(void *h_out, void *h_in, int dir, QudaInvertParam *param) { - auto profile = pushProfile(profileCovDev, param->secs, param->gflops); + auto profile = pushProfile(profileCovDev, param); QudaInvertParam &inv_param = *param; - const auto &gauge = *gaugePrecise; //(inv_param->dslash_type != QUDA_ASQTAD_DSLASH) ? *gaugePrecise : *gaugeFatPrecise; inv_param.solution_type = QUDA_MAT_SOLUTION; @@ -2544,7 +2540,7 @@ void eigensolveQuda(void **host_evecs, double _Complex *host_evals, QudaEigParam // This will define the operator to be eigensolved. 
QudaInvertParam *inv_param = eig_param->invert_param; - auto profile = pushProfile(profileEigensolve, inv_param->secs, inv_param->gflops); + auto profile = pushProfile(profileEigensolve, inv_param); // QUDA can employ even-odd preconditioning to an operator. // For the eigensolver the solution type must match @@ -2787,7 +2783,7 @@ multigrid_solver::multigrid_solver(QudaMultigridParam &mg_param) void *newMultigridQuda(QudaMultigridParam *mg_param) { profilerStart(__func__); - auto profile = pushProfile(profileInvert, mg_param->secs, mg_param->gflops); + auto profile = pushProfile(profileInvert, mg_param->invert_param); pushVerbosity(mg_param->invert_param->verbosity); auto *mg = new multigrid_solver(*mg_param); @@ -2804,7 +2800,7 @@ void destroyMultigridQuda(void *mg) { void updateMultigridQuda(void *mg_, QudaMultigridParam *mg_param) { profilerStart(__func__); - auto profile = pushProfile(profileInvert, mg_param->secs, mg_param->gflops); + auto profile = pushProfile(profileInvert, mg_param->invert_param); pushVerbosity(mg_param->invert_param->verbosity); profileInvert.TPSTART(QUDA_PROFILE_PREAMBLE); @@ -2913,7 +2909,7 @@ void updateMultigridQuda(void *mg_, QudaMultigridParam *mg_param) void dumpMultigridQuda(void *mg_, QudaMultigridParam *mg_param) { profilerStart(__func__); - auto profile = pushProfile(profileInvert, mg_param->secs, mg_param->gflops); + auto profile = pushProfile(profileInvert, mg_param->invert_param); pushVerbosity(mg_param->invert_param->verbosity); auto *mg = static_cast(mg_); @@ -2988,7 +2984,7 @@ deflated_solver::deflated_solver(QudaEigParam &eig_param, TimeProfile &profile) } void* newDeflationQuda(QudaEigParam *eig_param) { - auto profile = pushProfile(profileInvert, eig_param->secs, eig_param->gflops); + auto profile = pushProfile(profileInvert, eig_param->invert_param); auto *defl = new deflated_solver(*eig_param, profileInvert); saveProfile(__func__); flushProfile(); @@ -3001,7 +2997,7 @@ void destroyDeflationQuda(void *df) { void 
invertQuda(void *hp_x, void *hp_b, QudaInvertParam *param) { - auto profile = pushProfile(profileInvert, param->secs, param->gflops); + auto profile = pushProfile(profileInvert, param); profilerStart(__func__); if (!initialized) errorQuda("QUDA not initialized"); @@ -3086,7 +3082,7 @@ void callMultiSrcQuda(void **_hp_x, void **_hp_b, QudaInvertParam *param, // col */ profilerStart(__func__); - auto profile = pushProfile(profileInvertMultiSrc, param->secs, param->gflops); + auto profile = pushProfile(profileInvertMultiSrc, param); CommKey split_key = {param->split_grid[0], param->split_grid[1], param->split_grid[2], param->split_grid[3]}; int num_sub_partition = quda::product(split_key); @@ -3424,7 +3420,7 @@ void dslashMultiSrcQuda(void **_hp_x, void **_hp_b, QudaInvertParam *param, Quda */ void invertMultiShiftQuda(void **hp_x, void *hp_b, QudaInvertParam *param) { - auto profile = pushProfile(profileMulti, param->secs, param->gflops); + auto profile = pushProfile(profileMulti, param); profilerStart(__func__); if (!initialized) errorQuda("QUDA not initialized"); @@ -4561,7 +4557,7 @@ void computeCloverForceQuda(void *h_mom, double dt, void **h_x, void **, double QudaInvertParam *inv_param) { using namespace quda; - auto profile = pushProfile(profileCloverForce, inv_param->secs, inv_param->gflops); + auto profile = pushProfile(profileCloverForce, inv_param); checkGaugeParam(gauge_param); if (!gaugePrecise) errorQuda("No resident gauge field"); @@ -4633,7 +4629,7 @@ void computeTMCloverForceQuda(void *h_mom, void **h_x, void **h_x0, double *coef QudaGaugeParam *gauge_param, QudaInvertParam *inv_param, int detratio) { using namespace quda; - auto profile = pushProfile(profileTMCloverForce, inv_param->secs, inv_param->gflops); + auto profile = pushProfile(profileTMCloverForce, inv_param); checkGaugeParam(gauge_param); if (!gaugePrecise) errorQuda("No resident gauge field"); @@ -5015,7 +5011,7 @@ void performWuppertalnStep(void *h_out, void *h_in, QudaInvertParam 
*inv_param, void performTwoLinkGaussianSmearNStep(void *h_in, QudaQuarkSmearParam *smear_param) { if (smear_param->n_steps == 0) return; - auto profile = pushProfile(profileGaussianSmear, smear_param->secs, smear_param->gflops); + auto profile = pushProfile(profileGaussianSmear, smear_param); QudaInvertParam *inv_param = smear_param->inv_param; @@ -5487,7 +5483,7 @@ static void check_param(double _Complex *host_sinks, void **host_quark, int n_qu void laphSinkProject(double _Complex *host_sinks, void **host_quark, int n_quark, int tile_quark, void **host_evec, int n_evec, int tile_evec, QudaInvertParam *inv_param, const int X[4]) { - auto profile = pushProfile(profileSinkProject, inv_param->secs, inv_param->gflops); + auto profile = pushProfile(profileSinkProject, inv_param); // check parameters are valid check_param(host_sinks, host_quark, n_quark, tile_quark, host_evec, n_evec, tile_evec, inv_param, X); diff --git a/lib/monitor.cpp b/lib/monitor.cpp index 972e3edd4f..fd84c4b59b 100644 --- a/lib/monitor.cpp +++ b/lib/monitor.cpp @@ -20,7 +20,7 @@ namespace quda Linked list that we record the evolving state of the device being monitored */ - static std::list state_history; + static std::vector state_history; /** @brief Return the time period for the monitor measurements. 
@@ -77,6 +77,10 @@ namespace quda while (is_running.load()) { auto state = device::get_state(); state_history.push_back(state); + + // periodically reserve larger state size to avoid push_back cost + if (state_history.size() % 100000 == 0) state_history.reserve(state_history.size() + 100000); + std::this_thread::sleep_for(get_period()); } } @@ -98,6 +102,8 @@ namespace quda if (is_enabled()) { warningQuda("Enabling device monitoring"); + // pre-reserve state_history size to avoid push_back cost + state_history.reserve(10000); start_time = std::chrono::high_resolution_clock::now(); try { // spawn monitoring thread and release @@ -187,5 +193,42 @@ namespace quda monitor_file.close(); } + size_t size() { return is_enabled() ? state_history.size() : 0; } + + state_t mean(size_t start, size_t end) + { + state_t mean; + double last_power = 0.0; + std::chrono::time_point last_time; + + if (start > 0 && end > start) { + auto start_time = state_history[start - 1].time; + auto end_time = state_history[end - 1].time; + for (auto i = start; i < end; i++) { + auto &state = state_history[i]; + if (i - start > 0) { + std::chrono::duration diff = state.time - last_time; + + // potential for non-uniform samples distribution, so integrate rather than sum + mean.power += state.power * diff.count(); + mean.temp += state.temp * diff.count(); + mean.clock += state.clock * diff.count(); + + // trapezoidal integration to compute energy + mean.energy += 0.5 * (state.power + last_power) * diff.count(); + } + last_power = state.power; + last_time = state.time; + } + + std::chrono::duration duration = end_time - start_time; + mean.power /= duration.count(); + mean.temp /= duration.count(); + mean.clock /= duration.count(); + } + + return mean; + } + } // namespace monitor } // namespace quda diff --git a/lib/quda_fortran.F90 b/lib/quda_fortran.F90 index fa46c70726..8eda362fbe 100644 --- a/lib/quda_fortran.F90 +++ b/lib/quda_fortran.F90 @@ -219,6 +219,10 @@ module quda_fortran integer(4) :: 
iter real(8) :: gflops real(8) :: secs + real(8) :: energy + real(8) :: power + real(8) :: temp + real(8) :: clock ! Enable auto-tuning? QudaTune :: tune diff --git a/lib/timer.cpp b/lib/timer.cpp index 6a122da4fd..154467ea8f 100644 --- a/lib/timer.cpp +++ b/lib/timer.cpp @@ -1,6 +1,7 @@ #include #include #include +#include "monitor.h" #ifdef INTERFACE_NVTX #include "nvtx3/nvToolsExt.h" @@ -240,15 +241,44 @@ namespace quda { static std::stack tp_stack; - pushProfile::pushProfile(TimeProfile &profile, double &secs, double &gflops) : - profile(profile), secs(secs), gflops(gflops), flops(Tunable::flops_global()) + static double double_dummy; + + pushProfile::pushProfile(TimeProfile &profile, QudaInvertParam *param) : + profile(profile), + secs(param ? param->secs : double_dummy), + gflops(param ? param->gflops : double_dummy), + energy(param ? param->energy : double_dummy), + power(param ? param->power : double_dummy), + temp(param ? param->temp : double_dummy), + clock(param ? param->clock : double_dummy), + flops(Tunable::flops_global()) + { + if (profile.Name() != getProfile().Name()) { + // only push to stack if this profile not already the active one + profile.TPSTART(QUDA_PROFILE_TOTAL); + tp_stack.push(&profile); + active = true; + monitor_start = monitor::size(); + } + } + pushProfile::pushProfile(TimeProfile &profile, QudaQuarkSmearParam *param) : + profile(profile), + secs(param ? param->secs : double_dummy), + gflops(param ? param->gflops : double_dummy), + energy(param ? param->energy : double_dummy), + power(param ? param->power : double_dummy), + temp(param ? param->temp : double_dummy), + clock(param ? 
param->clock : double_dummy), + flops(Tunable::flops_global()) { if (profile.Name() != getProfile().Name()) { // only push to stack if this profile not already the active one profile.TPSTART(QUDA_PROFILE_TOTAL); tp_stack.push(&profile); active = true; + comm_barrier(); + monitor_start = monitor::size(); } } @@ -260,9 +290,34 @@ namespace quda { if (&(this->profile) != &profile) errorQuda("Popped profile is not the expected one"); tp_stack.pop(); profile.TPSTOP(QUDA_PROFILE_TOTAL); + secs = profile.Last(QUDA_PROFILE_TOTAL); + comm_allreduce_max(secs); + gflops = (Tunable::flops_global() - flops) * 1e-9; - if (&gflops != &gflops_dummy) comm_allreduce_sum(gflops); + if (&gflops != &double_dummy) comm_allreduce_sum(gflops); + + // make sure all processes will start + std::vector monitor_start_global = {static_cast(monitor_start)}; + comm_allreduce_min(monitor_start_global); + if (monitor_start_global[0] > 0) { + monitor_end = monitor::size(); + auto mean_state = monitor::mean(monitor_start, monitor_end); + energy = mean_state.energy; + comm_allreduce_sum(energy); + + power = mean_state.power; + comm_allreduce_sum(power); + power /= comm_size(); + + temp = mean_state.temp; + comm_allreduce_sum(temp); + temp /= comm_size(); + + clock = mean_state.clock; + comm_allreduce_sum(clock); + clock /= comm_size(); + } // cache is written out even if a long benchmarking job gets interrupted saveTuneCache(); From a781103a384b1f2aff0d88a18eb73f9999e1411d Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Mon, 29 Jul 2024 04:13:07 -0700 Subject: [PATCH 043/103] We should probably use MPI_THREAD_FUNNELED given we have threads now.... 
--- tests/utils/host_utils.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/utils/host_utils.cpp b/tests/utils/host_utils.cpp index 140eccd8a8..a3be4a0d6b 100644 --- a/tests/utils/host_utils.cpp +++ b/tests/utils/host_utils.cpp @@ -410,7 +410,7 @@ void initComms(int, char **, int *const commDims) #if defined(QMP_COMMS) QMP_thread_level_t tl; - QMP_init_msg_passing(&argc, &argv, QMP_THREAD_SINGLE, &tl); + QMP_init_msg_passing(&argc, &argv, QMP_THREAD_FUNNELED, &tl); // make sure the QMP logical ordering matches QUDA's if (rank_order == 0) { @@ -421,7 +421,7 @@ void initComms(int, char **, int *const commDims) QMP_declare_logical_topology_map(commDims, 4, map, 4); } #elif defined(MPI_COMMS) - MPI_Init(&argc, &argv); + MPI_Init_thread(&argc, &argv, MPI_THREAD_FUNNELED); #endif QudaCommsMap func = rank_order == 0 ? lex_rank_from_coords_t : lex_rank_from_coords_x; From 8d3b59e4d73e6688b4c97f4506c3f16682f9488d Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Mon, 29 Jul 2024 04:13:32 -0700 Subject: [PATCH 044/103] Report energy when running the solver now --- tests/invert_test.cpp | 17 ++++++++++++++--- tests/staggered_invert_test.cpp | 20 ++++++++++++++++---- 2 files changed, 30 insertions(+), 7 deletions(-) diff --git a/tests/invert_test.cpp b/tests/invert_test.cpp index 890d357078..9bc95b381b 100644 --- a/tests/invert_test.cpp +++ b/tests/invert_test.cpp @@ -236,7 +236,11 @@ std::vector> solve(test_t param) mg_preconditioner = newMultigridQuda(&mg_param); inv_param.preconditioner = mg_preconditioner; - printfQuda("MG Setup Done: %g secs, %g Gflops\n", mg_param.secs, mg_param.gflops / mg_param.secs); + printfQuda("MG Setup Done: %g secs, %g Gflops\n", mg_param.invert_param->secs, + mg_param.invert_param->gflops / mg_param.invert_param->secs); + printfQuda("Energy = %g J, Mean power = %g W, mean temp = %g C, mean clock = %f\n", + mg_param.invert_param->energy, mg_param.invert_param->power, + mg_param.invert_param->temp, 
mg_param.invert_param->clock); } // Vector construct START @@ -330,6 +334,8 @@ std::vector> solve(test_t param) iter[i] = inv_param.iter; printfQuda("Done: %i iter / %g secs = %g Gflops\n", inv_param.iter, inv_param.secs, inv_param.gflops / inv_param.secs); + printfQuda("Energy = %g J, Mean power = %g W, mean temp = %g C, mean clock = %f\n", + inv_param.energy, inv_param.power, inv_param.temp, inv_param.clock); } } else { @@ -359,8 +365,13 @@ std::vector> solve(test_t param) quda::comm_allreduce_sum(inv_param.gflops); inv_param.gflops /= quda::comm_size() / num_sub_partition; quda::comm_allreduce_max(inv_param.secs); - printfQuda("Done: %d sub-partitions - %i iter / %g secs = %g Gflops\n", num_sub_partition, inv_param.iter, - inv_param.secs, inv_param.gflops / inv_param.secs); + printfQuda("Done: %d sub-partitions - %i iter / %g secs = %g Gflops, %g secs per source\n", + num_sub_partition, inv_param.iter, + inv_param.secs, inv_param.gflops / inv_param.secs, + inv_param.secs / Nsrc_tile); + printfQuda("Energy = %g J (%g J per source), Mean power = %g W, mean temp = %g C, mean clock = %f\n", + inv_param.energy, inv_param.energy / Nsrc_tile, + inv_param.power, inv_param.temp, inv_param.clock); } } diff --git a/tests/staggered_invert_test.cpp b/tests/staggered_invert_test.cpp index a557c15797..460fe7de48 100644 --- a/tests/staggered_invert_test.cpp +++ b/tests/staggered_invert_test.cpp @@ -281,7 +281,11 @@ std::vector> solve(test_t param) mg_preconditioner = newMultigridQuda(&mg_param); inv_param.preconditioner = mg_preconditioner; - printfQuda("MG Setup Done: %g secs, %g Gflops\n", mg_param.secs, mg_param.gflops / mg_param.secs); + printfQuda("MG Setup Done: %g secs, %g Gflops\n", mg_param.invert_param->secs, + mg_param.invert_param->gflops / mg_param.invert_param->secs); + printfQuda("Energy = %g J, Mean power = %g W, mean temp = %g C, mean clock = %f\n", + mg_param.invert_param->energy, mg_param.invert_param->power, + mg_param.invert_param->temp, 
mg_param.invert_param->clock); } // Staggered vector construct START @@ -384,8 +388,10 @@ std::vector> solve(test_t param) time[n] = inv_param.secs; gflops[n] = inv_param.gflops / inv_param.secs; iter[n] = inv_param.iter; - printfQuda("Done: %i iter / %g secs = %g Gflops\n\n", inv_param.iter, inv_param.secs, + printfQuda("Done: %i iter / %g secs = %g Gflops\n", inv_param.iter, inv_param.secs, inv_param.gflops / inv_param.secs); + printfQuda("Energy = %g J, Mean power = %g W, mean temp = %g C, mean clock = %f\n\n", + inv_param.energy, inv_param.power, inv_param.temp, inv_param.clock); } } else { @@ -401,6 +407,8 @@ std::vector> solve(test_t param) _hp_b[i] = in[j + i].data(); } + if (inv_deflate) + eig_param.preserve_deflation = j < Nsrc - Nsrc_tile ? QUDA_BOOLEAN_TRUE : QUDA_BOOLEAN_FALSE; invertMultiSrcQuda(_hp_x.data(), _hp_b.data(), &inv_param); // move residuals to (i+j)^th location for verification after solves have finished @@ -414,8 +422,12 @@ std::vector> solve(test_t param) quda::comm_allreduce_sum(inv_param.gflops); inv_param.gflops /= comm_size() / num_sub_partition; quda::comm_allreduce_max(inv_param.secs); - printfQuda("Done: %d sub-partitions - %i iter / %g secs = %g Gflops\n\n", num_sub_partition, inv_param.iter, - inv_param.secs, inv_param.gflops / inv_param.secs); + printfQuda("Done: %d sub-partitions - %i iter / %g secs = %g Gflops, %g secs per source\n", + num_sub_partition, inv_param.iter, + inv_param.secs, inv_param.gflops / inv_param.secs, inv_param.secs / Nsrc_tile); + printfQuda("Energy = %g J (%g J per source), Mean power = %g W, mean temp = %g C, mean clock = %f\n\n", + inv_param.energy, inv_param.energy / Nsrc_tile, + inv_param.power, inv_param.temp, inv_param.clock); } } From 56ddf51a59ad5d86ac0e4e3a2cc97dcdff89473b Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Mon, 29 Jul 2024 04:15:23 -0700 Subject: [PATCH 045/103] Fix Ampere+ mma kernels --- include/targets/cuda/mma_tensor_op/smma_m16n8_sm80.cuh | 5 +++++ 1 file changed, 5 
insertions(+) diff --git a/include/targets/cuda/mma_tensor_op/smma_m16n8_sm80.cuh b/include/targets/cuda/mma_tensor_op/smma_m16n8_sm80.cuh index e41e089fb2..dc8df91b16 100644 --- a/include/targets/cuda/mma_tensor_op/smma_m16n8_sm80.cuh +++ b/include/targets/cuda/mma_tensor_op/smma_m16n8_sm80.cuh @@ -314,6 +314,11 @@ namespace quda for (int i = 0; i < warp_m * warp_n * thread_count; i++) { reg[i] *= alpha; } } + __device__ inline void axpy(float alpha, OperandC x) { +#pragma unroll + for (int i = 0; i < warp_m * warp_n * thread_count; i++) { reg[i] += alpha * x.reg[i]; } + } + template __device__ void store(void *ptr, int warp_row, int warp_col, const WarpRegisterMapping &wrm) { // This method is only used for the mobius preconditioner where shuffle_t = half. From 20d33dd38679723c278db8eb2a6c6a514b75b6e0 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Mon, 29 Jul 2024 11:13:04 -0700 Subject: [PATCH 046/103] Fix staggered MG bug --- lib/multigrid.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/lib/multigrid.cpp b/lib/multigrid.cpp index 9840c420d4..4f342b9e51 100644 --- a/lib/multigrid.cpp +++ b/lib/multigrid.cpp @@ -794,7 +794,9 @@ namespace quda auto &tmp2 = fine_tmp[1]; auto &tmp_coarse = coarse_tmp[0]; - auto B_norm = norm2(param.B); + cvector B_norm; + if (param.transfer_type == QUDA_TRANSFER_AGGREGATE) B_norm = norm2(param.B); + // No need to check (projector) v_k for staggered case if (param.transfer_type == QUDA_TRANSFER_AGGREGATE) { From 5f8f398ed1a9f9b891594547784edc7d7027bb16 Mon Sep 17 00:00:00 2001 From: Jiqun Tu Date: Mon, 29 Jul 2024 20:20:27 +0000 Subject: [PATCH 047/103] Clean up the coarse dslash MMA code: - Rescaling to the FP16 max of 65504 instead of the arbitrary 1e4. - Only applies rescaling to FP16-ranged MMA types. 
--- include/kernels/dslash_coarse_mma.cuh | 619 +++++++++++++----- include/targets/cuda/mma_tensor_op/gemm.cuh | 34 +- .../cuda/mma_tensor_op/hmma_m16n16k4_sm70.cuh | 8 +- .../cuda/mma_tensor_op/hmma_m16n8k8_sm80.cuh | 5 + .../cuda/mma_tensor_op/hmma_tfloat32_sm80.cuh | 5 + .../cuda/mma_tensor_op/smma_m16n16k4_sm70.cuh | 5 + .../cuda/mma_tensor_op/smma_m16n8_sm80.cuh | 5 + 7 files changed, 485 insertions(+), 196 deletions(-) diff --git a/include/kernels/dslash_coarse_mma.cuh b/include/kernels/dslash_coarse_mma.cuh index 60cdef5343..3c36631842 100644 --- a/include/kernels/dslash_coarse_mma.cuh +++ b/include/kernels/dslash_coarse_mma.cuh @@ -123,6 +123,8 @@ namespace quda using mma_t = typename Arg::mma_t; using Config = mma::MmaConfig; + constexpr bool do_rescale = mma_t::do_rescale(); + static_assert(M % Arg::bM == 0, "M %% Arg::bM != 0.\n"); static_assert(N % Arg::bN == 0, "N %% Arg::bN != 0.\n"); static_assert(K % Arg::bK == 0, "K %% Arg::bK != 0.\n"); @@ -162,266 +164,531 @@ namespace quda backward_idx[d] = linkIndexHop(coord, arg.dim, d, -arg.nFace); } - auto dslash_forward_producer = [&](int d, float &scale_inv_a, float &scale_inv_b, int k_offset) { - const int fwd_idx = forward_idx[d]; + if constexpr (do_rescale) { + auto dslash_forward_producer = [&](int d, float &scale_inv_a, float &scale_inv_b, int k_offset) { + const int fwd_idx = forward_idx[d]; + + if (forward_exterior[d]) { + if constexpr (doHalo()) { + int ghost_idx = ghostFaceIndex<1>(coord, arg.dim, d, arg.nFace); + + auto a = arg.Y(Arg::dagger ? 
d : d + 4, parity, x_cb, 0, 0); + auto b = arg.halo.Ghost(d, 1, their_spinor_parity, ghost_idx, 0, 0, 0); + constexpr bool a_dagger = false; + constexpr bool b_dagger = false; - if (forward_exterior[d]) { - if constexpr (doHalo()) { - int ghost_idx = ghostFaceIndex<1>(coord, arg.dim, d, arg.nFace); + using store_b_ghost_t = complex; + auto smem_tmp_b_ghost = reinterpret_cast(smem_tmp_b); + + __syncthreads(); + pipe.producer_acquire(); + scale_inv_a = a_loader.template g2tmp(a, m_offset, k_offset, smem_tmp_a, pipe); + scale_inv_b = b_loader.template g2tmp(b, n_offset, k_offset, smem_tmp_b_ghost, pipe); + pipe.producer_commit(); + } + } else if constexpr (doBulk()) { auto a = arg.Y(Arg::dagger ? d : d + 4, parity, x_cb, 0, 0); - auto b = arg.halo.Ghost(d, 1, their_spinor_parity, ghost_idx, 0, 0, 0); + auto b = arg.inA(their_spinor_parity, fwd_idx, 0, 0, 0); constexpr bool a_dagger = false; constexpr bool b_dagger = false; - using store_b_ghost_t = complex; - auto smem_tmp_b_ghost = reinterpret_cast(smem_tmp_b); - __syncthreads(); pipe.producer_acquire(); scale_inv_a = a_loader.template g2tmp(a, m_offset, k_offset, smem_tmp_a, pipe); - scale_inv_b = b_loader.template g2tmp(b, n_offset, k_offset, smem_tmp_b_ghost, pipe); + scale_inv_b = b_loader.template g2tmp(b, n_offset, k_offset, smem_tmp_b, pipe); pipe.producer_commit(); } - } else if constexpr (doBulk()) { + }; + + auto dslash_forward_consumer = [&](int d, float scale_inv_a, float scale_inv_b) -> float { + float rescale_factor; + if (forward_exterior[d]) { + if constexpr (doHalo()) { + constexpr bool a_dagger = false; + constexpr bool b_dagger = false; + + using a_wrapper_t = decltype(arg.Y(0, 0, 0, 0, 0)); + using b_wrapper_t = decltype(arg.halo.Ghost(0, 0, 0, 0, 0, 0, 0)); + using store_b_ghost_t = complex; + auto smem_tmp_b_ghost = reinterpret_cast(smem_tmp_b); + constexpr bool a_fixed = a_wrapper_t::fixed; + constexpr bool b_fixed = b_wrapper_t::fixed; + + pipe.consumer_wait(); + __syncthreads(); + float 
rescale_factor_a = a_loader.template tmp2s_rescale( + smem_tmp_a, scale_inv_a, smem_obj_a_real, smem_obj_a_imag); + float rescale_factor_b = b_loader.template tmp2s_rescale( + smem_tmp_b_ghost, scale_inv_b, smem_obj_b_real, smem_obj_b_imag); + pipe.consumer_release(); + __syncthreads(); + rescale_factor = rescale_factor_a * rescale_factor_b; + } + } else if constexpr (doBulk()) { - auto a = arg.Y(Arg::dagger ? d : d + 4, parity, x_cb, 0, 0); - auto b = arg.inA(their_spinor_parity, fwd_idx, 0, 0, 0); - constexpr bool a_dagger = false; - constexpr bool b_dagger = false; + constexpr bool a_dagger = false; + constexpr bool b_dagger = false; - __syncthreads(); - pipe.producer_acquire(); - scale_inv_a = a_loader.template g2tmp(a, m_offset, k_offset, smem_tmp_a, pipe); - scale_inv_b = b_loader.template g2tmp(b, n_offset, k_offset, smem_tmp_b, pipe); - pipe.producer_commit(); - } - }; + using a_wrapper_t = decltype(arg.Y(0, 0, 0, 0, 0)); + using b_wrapper_t = decltype(arg.inA(0, 0, 0, 0, 0)); + constexpr bool a_fixed = a_wrapper_t::fixed; + constexpr bool b_fixed = b_wrapper_t::fixed; - auto dslash_forward_consumer = [&](int d, float scale_inv_a, float scale_inv_b) -> float { - float rescale_factor; - if (forward_exterior[d]) { - if constexpr (doHalo()) { - constexpr bool a_dagger = false; + pipe.consumer_wait(); + __syncthreads(); + float rescale_factor_a = a_loader.template tmp2s_rescale( + smem_tmp_a, scale_inv_a, smem_obj_a_real, smem_obj_a_imag); + float rescale_factor_b = b_loader.template tmp2s_rescale( + smem_tmp_b, scale_inv_b, smem_obj_b_real, smem_obj_b_imag); + pipe.consumer_release(); + __syncthreads(); + rescale_factor = rescale_factor_a * rescale_factor_b; + } + return rescale_factor; + }; + + auto dslash_forward_compute = [&](int d, float rescale_factor) { + if (forward_exterior[d] && doHalo() || doBulk()) { + accumulator.mma_rescale(smem_obj_a_real, smem_obj_a_imag, smem_obj_b_real, smem_obj_b_imag, rescale_factor); + } + }; + + auto 
dslash_backward_producer = [&](int d, float &scale_inv_a, float &scale_inv_b, int k_offset) { + const int back_idx = backward_idx[d]; + + if (backward_exterior[d]) { + if constexpr (doHalo()) { + const int ghost_idx = ghostFaceIndex<0>(coord, arg.dim, d, arg.nFace); + + auto a = arg.Y.Ghost(Arg::dagger ? d + 4 : d, 1 - parity, ghost_idx, 0, 0); + auto b = arg.halo.Ghost(d, 0, their_spinor_parity, ghost_idx, 0, 0, 0); + constexpr bool a_dagger = true; + constexpr bool b_dagger = false; + + using store_b_ghost_t = complex; + auto smem_tmp_b_ghost = reinterpret_cast(smem_tmp_b); + + __syncthreads(); + pipe.producer_acquire(); + scale_inv_a = a_loader.template g2tmp(a, m_offset, k_offset, smem_tmp_a, pipe); + scale_inv_b = b_loader.template g2tmp(b, n_offset, k_offset, smem_tmp_b_ghost, pipe); + pipe.producer_commit(); + } + } else if constexpr (doBulk()) { + const int gauge_idx = back_idx; + + auto a = arg.Y(Arg::dagger ? d + 4 : d, 1 - parity, gauge_idx, 0, 0); + auto b = arg.inA(their_spinor_parity, back_idx, 0, 0, 0); + constexpr bool a_dagger = true; + constexpr bool b_dagger = false; + + __syncthreads(); + pipe.producer_acquire(); + scale_inv_a = a_loader.template g2tmp(a, m_offset, k_offset, smem_tmp_a, pipe); + scale_inv_b = b_loader.template g2tmp(b, n_offset, k_offset, smem_tmp_b, pipe); + pipe.producer_commit(); + } + }; + + auto dslash_backward_consumer = [&](int d, float scale_inv_a, float scale_inv_b) -> float { + float rescale_factor; + if (backward_exterior[d]) { + if constexpr (doHalo()) { + constexpr bool a_dagger = true; + constexpr bool b_dagger = false; + + using a_wrapper_t = decltype(arg.Y.Ghost(0, 0, 0, 0, 0)); + using b_wrapper_t = decltype(arg.halo.Ghost(0, 0, 0, 0, 0, 0, 0)); + using store_b_ghost_t = complex; + auto smem_tmp_b_ghost = reinterpret_cast(smem_tmp_b); + constexpr bool a_fixed = a_wrapper_t::fixed; + constexpr bool b_fixed = b_wrapper_t::fixed; + + pipe.consumer_wait(); + __syncthreads(); + float rescale_factor_a = 
a_loader.template tmp2s_rescale( + smem_tmp_a, scale_inv_a, smem_obj_a_real, smem_obj_a_imag); + float rescale_factor_b = b_loader.template tmp2s_rescale( + smem_tmp_b_ghost, scale_inv_b, smem_obj_b_real, smem_obj_b_imag); + pipe.consumer_release(); + __syncthreads(); + rescale_factor = rescale_factor_a * rescale_factor_b; + } + } else if constexpr (doBulk()) { + constexpr bool a_dagger = true; constexpr bool b_dagger = false; using a_wrapper_t = decltype(arg.Y(0, 0, 0, 0, 0)); - using b_wrapper_t = decltype(arg.halo.Ghost(0, 0, 0, 0, 0, 0, 0)); - using store_b_ghost_t = complex; - auto smem_tmp_b_ghost = reinterpret_cast(smem_tmp_b); + using b_wrapper_t = decltype(arg.inA(0, 0, 0, 0, 0)); constexpr bool a_fixed = a_wrapper_t::fixed; constexpr bool b_fixed = b_wrapper_t::fixed; pipe.consumer_wait(); __syncthreads(); - float rescale_factor_a = a_loader.template tmp2s_rescale(smem_tmp_a, scale_inv_a, smem_obj_a_real, smem_obj_a_imag); - float rescale_factor_b = b_loader.template tmp2s_rescale(smem_tmp_b_ghost, scale_inv_b, smem_obj_b_real, - smem_obj_b_imag); + float rescale_factor_a = a_loader.template tmp2s_rescale( + smem_tmp_a, scale_inv_a, smem_obj_a_real, smem_obj_a_imag); + float rescale_factor_b = b_loader.template tmp2s_rescale( + smem_tmp_b, scale_inv_b, smem_obj_b_real, smem_obj_b_imag); pipe.consumer_release(); __syncthreads(); rescale_factor = rescale_factor_a * rescale_factor_b; } - } else if constexpr (doBulk()) { + return rescale_factor; + }; + + auto dslash_backward_compute = [&](int d, float rescale_factor) { + if (backward_exterior[d] && doHalo() || doBulk()) { + accumulator.mma_rescale(smem_obj_a_real, smem_obj_a_imag, smem_obj_b_real, smem_obj_b_imag, rescale_factor); + } + }; + + auto clover_producer = [&](float &scale_inv_a, float &scale_inv_b, int k_offset) { + const int spinor_parity = (arg.nParity == 2) ? 
parity : 0; + + auto a = arg.X(0, parity, x_cb, 0, 0); + auto b = arg.inB(spinor_parity, x_cb, 0, 0); + constexpr bool a_dagger = Arg::dagger; + constexpr bool b_dagger = false; + + __syncthreads(); + pipe.producer_acquire(); + scale_inv_a = a_loader.template g2tmp(a, m_offset, k_offset, smem_tmp_a, pipe); + scale_inv_b = b_loader.template g2tmp(b, n_offset, k_offset, smem_tmp_b, pipe); + pipe.producer_commit(); + }; - constexpr bool a_dagger = false; + auto clover_consumer = [&](float scale_inv_a, float scale_inv_b) -> float { + constexpr bool a_dagger = Arg::dagger; constexpr bool b_dagger = false; - using a_wrapper_t = decltype(arg.Y(0, 0, 0, 0, 0)); - using b_wrapper_t = decltype(arg.inA(0, 0, 0, 0, 0)); + using a_wrapper_t = decltype(arg.X(0, 0, 0, 0, 0)); + using b_wrapper_t = decltype(arg.inB(0, 0, 0, 0)); constexpr bool a_fixed = a_wrapper_t::fixed; constexpr bool b_fixed = b_wrapper_t::fixed; pipe.consumer_wait(); __syncthreads(); - float rescale_factor_a = a_loader.template tmp2s_rescale(smem_tmp_a, scale_inv_a, smem_obj_a_real, smem_obj_a_imag); - float rescale_factor_b = b_loader.template tmp2s_rescale(smem_tmp_b, scale_inv_b, smem_obj_b_real, smem_obj_b_imag); + float rescale_factor_a = a_loader.template tmp2s_rescale( + smem_tmp_a, scale_inv_a, smem_obj_a_real, smem_obj_a_imag); + float rescale_factor_b = b_loader.template tmp2s_rescale( + smem_tmp_b, scale_inv_b, smem_obj_b_real, smem_obj_b_imag); pipe.consumer_release(); __syncthreads(); - rescale_factor = rescale_factor_a * rescale_factor_b; - } - return rescale_factor; - }; + return rescale_factor_a * rescale_factor_b; + }; - auto dslash_forward_compute = [&](int d, float rescale_factor) { - if (forward_exterior[d] && doHalo() || doBulk()) { + auto clover_compute = [&](float rescale_factor) { accumulator.mma_rescale(smem_obj_a_real, smem_obj_a_imag, smem_obj_b_real, smem_obj_b_imag, rescale_factor); + }; + + float scale_inv_a; + float scale_inv_b; + + if constexpr (Arg::dslash) { + + 
dslash_forward_producer(0, scale_inv_a, scale_inv_b, 0); + + for (int k_offset = 0; k_offset < K; k_offset += Arg::bK) { + + // Forward gather - compute fwd offset for spinor fetch +#pragma unroll + for (int d = 0; d < Arg::nDim; d++) // loop over dimension + { + float rescale_factor = dslash_forward_consumer(d, scale_inv_a, scale_inv_b); + if (d < 3) { + dslash_forward_producer(d + 1, scale_inv_a, scale_inv_b, k_offset); + } else { + dslash_backward_producer(0, scale_inv_a, scale_inv_b, k_offset); + } + dslash_forward_compute(d, rescale_factor); + } // nDim + + // Backward gather - compute back offset for spinor and gauge fetch +#pragma unroll + for (int d = 0; d < Arg::nDim; d++) { + float rescale_factor = dslash_backward_consumer(d, scale_inv_a, scale_inv_b); + if (d < 3) { + dslash_backward_producer(d + 1, scale_inv_a, scale_inv_b, k_offset); + } else if (k_offset + Arg::bK < K) { + dslash_forward_producer(0, scale_inv_a, scale_inv_b, k_offset + Arg::bK); + } else if constexpr (doBulk() && Arg::clover) { + clover_producer(scale_inv_a, scale_inv_b, 0); + } + dslash_backward_compute(d, rescale_factor); + } // nDim + } + + accumulator.ax(-arg.kappa); } - }; - auto dslash_backward_producer = [&](int d, float &scale_inv_a, float &scale_inv_b, int k_offset) { - const int back_idx = backward_idx[d]; + /** + Applies the coarse clover matrix on a given parity and + checkerboard site index + */ + if constexpr (doBulk() && Arg::clover) { + if constexpr (!Arg::dslash) { clover_producer(scale_inv_a, scale_inv_b, 0); } + for (int k_offset = 0; k_offset < K; k_offset += Arg::bK) { + float rescale_factor = clover_consumer(scale_inv_a, scale_inv_b); + if (k_offset + Arg::bK < K) { clover_producer(scale_inv_a, scale_inv_b, k_offset + Arg::bK); } + clover_compute(rescale_factor); + } + } - if (backward_exterior[d]) { - if constexpr (doHalo()) { - const int ghost_idx = ghostFaceIndex<0>(coord, arg.dim, d, arg.nFace); + } else { - auto a = arg.Y.Ghost(Arg::dagger ? 
d + 4 : d, 1 - parity, ghost_idx, 0, 0); - auto b = arg.halo.Ghost(d, 0, their_spinor_parity, ghost_idx, 0, 0, 0); - constexpr bool a_dagger = true; - constexpr bool b_dagger = false; + auto dslash_forward_producer = [&](int d, float &scale_inv_a, float &scale_inv_b, int k_offset) { + const int fwd_idx = forward_idx[d]; - using store_b_ghost_t = complex; - auto smem_tmp_b_ghost = reinterpret_cast(smem_tmp_b); + if (forward_exterior[d]) { + if constexpr (doHalo()) { + int ghost_idx = ghostFaceIndex<1>(coord, arg.dim, d, arg.nFace); + + auto a = arg.Y(Arg::dagger ? d : d + 4, parity, x_cb, 0, 0); + auto b = arg.halo.Ghost(d, 1, their_spinor_parity, ghost_idx, 0, 0, 0); + constexpr bool a_dagger = false; + constexpr bool b_dagger = false; + + using store_b_ghost_t = complex; + auto smem_tmp_b_ghost = reinterpret_cast(smem_tmp_b); + + __syncthreads(); + pipe.producer_acquire(); + scale_inv_a = a_loader.template g2tmp(a, m_offset, k_offset, smem_tmp_a, pipe); + scale_inv_b = b_loader.template g2tmp(b, n_offset, k_offset, smem_tmp_b_ghost, pipe); + pipe.producer_commit(); + } + } else if constexpr (doBulk()) { + + auto a = arg.Y(Arg::dagger ? 
d : d + 4, parity, x_cb, 0, 0); + auto b = arg.inA(their_spinor_parity, fwd_idx, 0, 0, 0); + constexpr bool a_dagger = false; + constexpr bool b_dagger = false; __syncthreads(); pipe.producer_acquire(); scale_inv_a = a_loader.template g2tmp(a, m_offset, k_offset, smem_tmp_a, pipe); - scale_inv_b = b_loader.template g2tmp(b, n_offset, k_offset, smem_tmp_b_ghost, pipe); + scale_inv_b = b_loader.template g2tmp(b, n_offset, k_offset, smem_tmp_b, pipe); pipe.producer_commit(); } - } else if constexpr (doBulk()) { - const int gauge_idx = back_idx; + }; + + auto dslash_forward_consumer = [&](int d, float scale_inv_a, float scale_inv_b) { + if (forward_exterior[d]) { + if constexpr (doHalo()) { + constexpr bool a_dagger = false; + constexpr bool b_dagger = false; + + using a_wrapper_t = decltype(arg.Y(0, 0, 0, 0, 0)); + using b_wrapper_t = decltype(arg.halo.Ghost(0, 0, 0, 0, 0, 0, 0)); + using store_b_ghost_t = complex; + auto smem_tmp_b_ghost = reinterpret_cast(smem_tmp_b); + constexpr bool a_fixed = a_wrapper_t::fixed; + constexpr bool b_fixed = b_wrapper_t::fixed; + + pipe.consumer_wait(); + __syncthreads(); + a_loader.template tmp2s(smem_tmp_a, scale_inv_a, smem_obj_a_real, smem_obj_a_imag); + b_loader.template tmp2s(smem_tmp_b_ghost, scale_inv_b, smem_obj_b_real, + smem_obj_b_imag); + pipe.consumer_release(); + __syncthreads(); + } + } else if constexpr (doBulk()) { - auto a = arg.Y(Arg::dagger ? 
d + 4 : d, 1 - parity, gauge_idx, 0, 0); - auto b = arg.inA(their_spinor_parity, back_idx, 0, 0, 0); - constexpr bool a_dagger = true; - constexpr bool b_dagger = false; + constexpr bool a_dagger = false; + constexpr bool b_dagger = false; - __syncthreads(); - pipe.producer_acquire(); - scale_inv_a = a_loader.template g2tmp(a, m_offset, k_offset, smem_tmp_a, pipe); - scale_inv_b = b_loader.template g2tmp(b, n_offset, k_offset, smem_tmp_b, pipe); - pipe.producer_commit(); - } - }; + using a_wrapper_t = decltype(arg.Y(0, 0, 0, 0, 0)); + using b_wrapper_t = decltype(arg.inA(0, 0, 0, 0, 0)); + constexpr bool a_fixed = a_wrapper_t::fixed; + constexpr bool b_fixed = b_wrapper_t::fixed; - auto dslash_backward_consumer = [&](int d, float scale_inv_a, float scale_inv_b) -> float { - float rescale_factor; - if (backward_exterior[d]) { - if constexpr (doHalo()) { + pipe.consumer_wait(); + __syncthreads(); + a_loader.template tmp2s(smem_tmp_a, scale_inv_a, smem_obj_a_real, smem_obj_a_imag); + b_loader.template tmp2s(smem_tmp_b, scale_inv_b, smem_obj_b_real, smem_obj_b_imag); + pipe.consumer_release(); + __syncthreads(); + } + }; + + auto dslash_forward_compute = [&](int d) { + if (forward_exterior[d] && doHalo() || doBulk()) { + accumulator.mma(smem_obj_a_real, smem_obj_a_imag, smem_obj_b_real, smem_obj_b_imag); + } + }; + + auto dslash_backward_producer = [&](int d, float &scale_inv_a, float &scale_inv_b, int k_offset) { + const int back_idx = backward_idx[d]; + + if (backward_exterior[d]) { + if constexpr (doHalo()) { + const int ghost_idx = ghostFaceIndex<0>(coord, arg.dim, d, arg.nFace); + + auto a = arg.Y.Ghost(Arg::dagger ? 
d + 4 : d, 1 - parity, ghost_idx, 0, 0); + auto b = arg.halo.Ghost(d, 0, their_spinor_parity, ghost_idx, 0, 0, 0); + constexpr bool a_dagger = true; + constexpr bool b_dagger = false; + + using store_b_ghost_t = complex; + auto smem_tmp_b_ghost = reinterpret_cast(smem_tmp_b); + + __syncthreads(); + pipe.producer_acquire(); + scale_inv_a = a_loader.template g2tmp(a, m_offset, k_offset, smem_tmp_a, pipe); + scale_inv_b = b_loader.template g2tmp(b, n_offset, k_offset, smem_tmp_b_ghost, pipe); + pipe.producer_commit(); + } + } else if constexpr (doBulk()) { + const int gauge_idx = back_idx; + + auto a = arg.Y(Arg::dagger ? d + 4 : d, 1 - parity, gauge_idx, 0, 0); + auto b = arg.inA(their_spinor_parity, back_idx, 0, 0, 0); constexpr bool a_dagger = true; constexpr bool b_dagger = false; - using a_wrapper_t = decltype(arg.Y.Ghost(0, 0, 0, 0, 0)); - using b_wrapper_t = decltype(arg.halo.Ghost(0, 0, 0, 0, 0, 0, 0)); - using store_b_ghost_t = complex; - auto smem_tmp_b_ghost = reinterpret_cast(smem_tmp_b); + __syncthreads(); + pipe.producer_acquire(); + scale_inv_a = a_loader.template g2tmp(a, m_offset, k_offset, smem_tmp_a, pipe); + scale_inv_b = b_loader.template g2tmp(b, n_offset, k_offset, smem_tmp_b, pipe); + pipe.producer_commit(); + } + }; + + auto dslash_backward_consumer = [&](int d, float scale_inv_a, float scale_inv_b) { + if (backward_exterior[d]) { + if constexpr (doHalo()) { + constexpr bool a_dagger = true; + constexpr bool b_dagger = false; + + using a_wrapper_t = decltype(arg.Y.Ghost(0, 0, 0, 0, 0)); + using b_wrapper_t = decltype(arg.halo.Ghost(0, 0, 0, 0, 0, 0, 0)); + using store_b_ghost_t = complex; + auto smem_tmp_b_ghost = reinterpret_cast(smem_tmp_b); + constexpr bool a_fixed = a_wrapper_t::fixed; + constexpr bool b_fixed = b_wrapper_t::fixed; + + pipe.consumer_wait(); + __syncthreads(); + a_loader.template tmp2s(smem_tmp_a, scale_inv_a, smem_obj_a_real, smem_obj_a_imag); + b_loader.template tmp2s(smem_tmp_b_ghost, scale_inv_b, smem_obj_b_real, + 
smem_obj_b_imag); + pipe.consumer_release(); + __syncthreads(); + } + } else if constexpr (doBulk()) { + constexpr bool a_dagger = true; + constexpr bool b_dagger = false; + + using a_wrapper_t = decltype(arg.Y(0, 0, 0, 0, 0)); + using b_wrapper_t = decltype(arg.inA(0, 0, 0, 0, 0)); constexpr bool a_fixed = a_wrapper_t::fixed; constexpr bool b_fixed = b_wrapper_t::fixed; pipe.consumer_wait(); __syncthreads(); - float rescale_factor_a = a_loader.template tmp2s_rescale(smem_tmp_a, scale_inv_a, smem_obj_a_real, smem_obj_a_imag); - float rescale_factor_b = b_loader.template tmp2s_rescale(smem_tmp_b_ghost, scale_inv_b, smem_obj_b_real, - smem_obj_b_imag); + a_loader.template tmp2s(smem_tmp_a, scale_inv_a, smem_obj_a_real, smem_obj_a_imag); + b_loader.template tmp2s(smem_tmp_b, scale_inv_b, smem_obj_b_real, smem_obj_b_imag); pipe.consumer_release(); __syncthreads(); - rescale_factor = rescale_factor_a * rescale_factor_b; } - } else if constexpr (doBulk()) { - constexpr bool a_dagger = true; + }; + + auto dslash_backward_compute = [&](int d) { + if (backward_exterior[d] && doHalo() || doBulk()) { + accumulator.mma(smem_obj_a_real, smem_obj_a_imag, smem_obj_b_real, smem_obj_b_imag); + } + }; + + auto clover_producer = [&](float &scale_inv_a, float &scale_inv_b, int k_offset) { + const int spinor_parity = (arg.nParity == 2) ? 
parity : 0; + + auto a = arg.X(0, parity, x_cb, 0, 0); + auto b = arg.inB(spinor_parity, x_cb, 0, 0); + constexpr bool a_dagger = Arg::dagger; constexpr bool b_dagger = false; - using a_wrapper_t = decltype(arg.Y(0, 0, 0, 0, 0)); - using b_wrapper_t = decltype(arg.inA(0, 0, 0, 0, 0)); + __syncthreads(); + pipe.producer_acquire(); + scale_inv_a = a_loader.template g2tmp(a, m_offset, k_offset, smem_tmp_a, pipe); + scale_inv_b = b_loader.template g2tmp(b, n_offset, k_offset, smem_tmp_b, pipe); + pipe.producer_commit(); + }; + + auto clover_consumer = [&](float scale_inv_a, float scale_inv_b) { + constexpr bool a_dagger = Arg::dagger; + constexpr bool b_dagger = false; + + using a_wrapper_t = decltype(arg.X(0, 0, 0, 0, 0)); + using b_wrapper_t = decltype(arg.inB(0, 0, 0, 0)); constexpr bool a_fixed = a_wrapper_t::fixed; constexpr bool b_fixed = b_wrapper_t::fixed; pipe.consumer_wait(); __syncthreads(); - float rescale_factor_a = a_loader.template tmp2s_rescale(smem_tmp_a, scale_inv_a, smem_obj_a_real, smem_obj_a_imag); - float rescale_factor_b = b_loader.template tmp2s_rescale(smem_tmp_b, scale_inv_b, smem_obj_b_real, smem_obj_b_imag); + a_loader.template tmp2s(smem_tmp_a, scale_inv_a, smem_obj_a_real, smem_obj_a_imag); + b_loader.template tmp2s(smem_tmp_b, scale_inv_b, smem_obj_b_real, smem_obj_b_imag); pipe.consumer_release(); __syncthreads(); - rescale_factor = rescale_factor_a * rescale_factor_b; - } - return rescale_factor; - }; + }; - auto dslash_backward_compute = [&](int d, float rescale_factor) { - if (backward_exterior[d] && doHalo() || doBulk()) { - accumulator.mma_rescale(smem_obj_a_real, smem_obj_a_imag, smem_obj_b_real, smem_obj_b_imag, rescale_factor); - } - }; - - auto clover_producer = [&](float &scale_inv_a, float &scale_inv_b, int k_offset) { - const int spinor_parity = (arg.nParity == 2) ? 
parity : 0; - - auto a = arg.X(0, parity, x_cb, 0, 0); - auto b = arg.inB(spinor_parity, x_cb, 0, 0); - constexpr bool a_dagger = Arg::dagger; - constexpr bool b_dagger = false; - - __syncthreads(); - pipe.producer_acquire(); - scale_inv_a = a_loader.template g2tmp(a, m_offset, k_offset, smem_tmp_a, pipe); - scale_inv_b = b_loader.template g2tmp(b, n_offset, k_offset, smem_tmp_b, pipe); - pipe.producer_commit(); - }; - - auto clover_consumer = [&](float scale_inv_a, float scale_inv_b) -> float { - constexpr bool a_dagger = Arg::dagger; - constexpr bool b_dagger = false; - - using a_wrapper_t = decltype(arg.X(0, 0, 0, 0, 0)); - using b_wrapper_t = decltype(arg.inB(0, 0, 0, 0)); - constexpr bool a_fixed = a_wrapper_t::fixed; - constexpr bool b_fixed = b_wrapper_t::fixed; - - pipe.consumer_wait(); - __syncthreads(); - float rescale_factor_a = a_loader.template tmp2s_rescale(smem_tmp_a, scale_inv_a, smem_obj_a_real, smem_obj_a_imag); - float rescale_factor_b = b_loader.template tmp2s_rescale(smem_tmp_b, scale_inv_b, smem_obj_b_real, smem_obj_b_imag); - pipe.consumer_release(); - __syncthreads(); - return rescale_factor_a * rescale_factor_b; - }; - - auto clover_compute = [&](float rescale_factor) { accumulator.mma_rescale(smem_obj_a_real, smem_obj_a_imag, smem_obj_b_real, smem_obj_b_imag, rescale_factor); }; + auto clover_compute + = [&]() { accumulator.mma(smem_obj_a_real, smem_obj_a_imag, smem_obj_b_real, smem_obj_b_imag); }; - float scale_inv_a; - float scale_inv_b; + float scale_inv_a; + float scale_inv_b; - if constexpr (Arg::dslash) { + if constexpr (Arg::dslash) { - dslash_forward_producer(0, scale_inv_a, scale_inv_b, 0); + dslash_forward_producer(0, scale_inv_a, scale_inv_b, 0); - for (int k_offset = 0; k_offset < K; k_offset += Arg::bK) { + for (int k_offset = 0; k_offset < K; k_offset += Arg::bK) { - // Forward gather - compute fwd offset for spinor fetch + // Forward gather - compute fwd offset for spinor fetch #pragma unroll - for (int d = 0; d < Arg::nDim; 
d++) // loop over dimension - { - float rescale_factor = dslash_forward_consumer(d, scale_inv_a, scale_inv_b); - if (d < 3) { - dslash_forward_producer(d + 1, scale_inv_a, scale_inv_b, k_offset); - } else { - dslash_backward_producer(0, scale_inv_a, scale_inv_b, k_offset); - } - dslash_forward_compute(d, rescale_factor); - } // nDim - - // Backward gather - compute back offset for spinor and gauge fetch + for (int d = 0; d < Arg::nDim; d++) // loop over dimension + { + dslash_forward_consumer(d, scale_inv_a, scale_inv_b); + if (d < 3) { + dslash_forward_producer(d + 1, scale_inv_a, scale_inv_b, k_offset); + } else { + dslash_backward_producer(0, scale_inv_a, scale_inv_b, k_offset); + } + dslash_forward_compute(d); + } // nDim + + // Backward gather - compute back offset for spinor and gauge fetch #pragma unroll - for (int d = 0; d < Arg::nDim; d++) { - float rescale_factor = dslash_backward_consumer(d, scale_inv_a, scale_inv_b); - if (d < 3) { - dslash_backward_producer(d + 1, scale_inv_a, scale_inv_b, k_offset); - } else if (k_offset + Arg::bK < K) { - dslash_forward_producer(0, scale_inv_a, scale_inv_b, k_offset + Arg::bK); - } else if constexpr (doBulk() && Arg::clover) { - clover_producer(scale_inv_a, scale_inv_b, 0); - } - dslash_backward_compute(d, rescale_factor); - } // nDim - } + for (int d = 0; d < Arg::nDim; d++) { + dslash_backward_consumer(d, scale_inv_a, scale_inv_b); + if (d < 3) { + dslash_backward_producer(d + 1, scale_inv_a, scale_inv_b, k_offset); + } else if (k_offset + Arg::bK < K) { + dslash_forward_producer(0, scale_inv_a, scale_inv_b, k_offset + Arg::bK); + } else if constexpr (doBulk() && Arg::clover) { + clover_producer(scale_inv_a, scale_inv_b, 0); + } + dslash_backward_compute(d); + } // nDim + } - accumulator.ax(-arg.kappa); - } + accumulator.ax(-arg.kappa); + } - /** - Applies the coarse clover matrix on a given parity and - checkerboard site index - */ - if constexpr (doBulk() && Arg::clover) { - if constexpr (!Arg::dslash) { 
clover_producer(scale_inv_a, scale_inv_b, 0); } - for (int k_offset = 0; k_offset < K; k_offset += Arg::bK) { - float rescale_factor = clover_consumer(scale_inv_a, scale_inv_b); - if (k_offset + Arg::bK < K) { clover_producer(scale_inv_a, scale_inv_b, k_offset + Arg::bK); } - clover_compute(rescale_factor); + /** + Applies the coarse clover matrix on a given parity and + checkerboard site index + */ + if constexpr (doBulk() && Arg::clover) { + if constexpr (!Arg::dslash) { clover_producer(scale_inv_a, scale_inv_b, 0); } + for (int k_offset = 0; k_offset < K; k_offset += Arg::bK) { + clover_consumer(scale_inv_a, scale_inv_b); + if (k_offset + Arg::bK < K) { clover_producer(scale_inv_a, scale_inv_b, k_offset + Arg::bK); } + clover_compute(); + } } } - return accumulator; } diff --git a/include/targets/cuda/mma_tensor_op/gemm.cuh b/include/targets/cuda/mma_tensor_op/gemm.cuh index 2185528cef..334675d495 100644 --- a/include/targets/cuda/mma_tensor_op/gemm.cuh +++ b/include/targets/cuda/mma_tensor_op/gemm.cuh @@ -47,9 +47,7 @@ namespace quda reg_imag = 0; } - inline __device__ float abs_max(float a, float max) { - return fmaxf(fabsf(a), max); - } + inline __device__ float abs_max(float a, float max) { return fmaxf(fabsf(a), max); } /** @brief Load from global memory and store data in registers. @@ -87,7 +85,7 @@ namespace quda } /** - @brief Load from global memory and store data in registers. + @brief Load from global memory and store data in registers while also applying a rescaling */ template inline __device__ void convert_x_rescale(float ®_real, float ®_imag, complex *p, int m_idx, int n_idx, @@ -125,8 +123,7 @@ namespace quda @brief Load from global memory and store data in registers. 
*/ template - inline __device__ float find_abs_max(complex *p, int m_idx, int n_idx, - float scale_inv) + inline __device__ float find_abs_max(complex *p, int m_idx, int n_idx, float scale_inv) { float this_max = 0.0f; @@ -135,11 +132,10 @@ namespace quda if (fixed) { this_max = abs_max(scale_inv * xx.real(), this_max); - auto scale_inv_conj = dagger ? -scale_inv : scale_inv; - this_max = abs_max(scale_inv_conj * xx.imag(), this_max); + this_max = abs_max(scale_inv * xx.imag(), this_max); } else { - this_max = abs_max(+xx.real(), this_max); - this_max = abs_max(dagger ? -xx.imag() : +xx.imag(), this_max); + this_max = abs_max(xx.real(), this_max); + this_max = abs_max(xx.imag(), this_max); } } else { auto xx = p[n_idx * ld + m_idx]; @@ -149,11 +145,10 @@ namespace quda if (fixed) { this_max = abs_max(scale_inv * xx.real(), this_max); - auto scale_inv_conj = dagger ? -scale_inv : scale_inv; - this_max = abs_max(scale_inv_conj * xx.imag(), this_max); + this_max = abs_max(scale_inv * xx.imag(), this_max); } else { this_max = abs_max(xx.real(), this_max); - this_max = abs_max(dagger ? -xx.imag() : xx.imag(), this_max); + this_max = abs_max(xx.imag(), this_max); } } @@ -314,13 +309,12 @@ namespace quda int gmem_k_offset = smem_k_offset; constexpr bool x = (transpose == dagger); - float this_max = find_abs_max(smem_ptr, gmem_m_offset, gmem_k_offset, - scale_inv); + float this_max = find_abs_max < x, fixed, dagger, + x ? 
bN + 4 : bM + 4 > (smem_ptr, gmem_m_offset, gmem_k_offset, scale_inv); thread_max = fmaxf(this_max, thread_max); } } - __syncthreads(); // block all-reduce thread_max using block_reduce_t = cub::BlockReduce; __shared__ typename block_reduce_t::TempStorage temp_storage; @@ -335,7 +329,8 @@ namespace quda } } __syncthreads(); - float block_rescale_factor = 1e4f / block_max_all; + + float block_rescale_factor = 65504.0f / block_max_all; // 65504 = the maximum FP16 number #pragma unroll for (int c = 0; c < warp_cycle; c++) { @@ -542,10 +537,10 @@ namespace quda } } + /** @brief Apply MMA, but doing a rescaling before accumulate into the final accumulator */ template __device__ inline void mma_rescale(const SmemObjA &smem_obj_a_real, const SmemObjA &smem_obj_a_imag, - const SmemObjB &smem_obj_b_real, const SmemObjB &smem_obj_b_imag, - float rescale) + const SmemObjB &smem_obj_b_real, const SmemObjB &smem_obj_b_imag, float rescale) { #pragma unroll @@ -573,6 +568,7 @@ namespace quda } } + /** @brief Apply MMA */ template __device__ inline void mma(const SmemObjA &smem_obj_a_real, const SmemObjA &smem_obj_a_imag, const SmemObjB &smem_obj_b_real, const SmemObjB &smem_obj_b_imag) diff --git a/include/targets/cuda/mma_tensor_op/hmma_m16n16k4_sm70.cuh b/include/targets/cuda/mma_tensor_op/hmma_m16n16k4_sm70.cuh index 04ec30c8f8..39d99bc93a 100644 --- a/include/targets/cuda/mma_tensor_op/hmma_m16n16k4_sm70.cuh +++ b/include/targets/cuda/mma_tensor_op/hmma_m16n16k4_sm70.cuh @@ -22,6 +22,11 @@ namespace quda static __device__ __host__ constexpr int inline pad_size(int m) { return m == 48 ? 
2 : 10; } + static constexpr bool do_rescale() + { + return true; // true because we use FP16 + } + static constexpr int MMA_M = 16; static constexpr int MMA_N = 16; static constexpr int MMA_K = 4; @@ -150,7 +155,8 @@ namespace quda for (int i = 0; i < 8; i++) { reg[i] *= alpha; } } - __device__ inline void axpy(float alpha, OperandC x) { + __device__ inline void axpy(float alpha, OperandC x) + { #pragma unroll for (int i = 0; i < 8; i++) { reg[i] += alpha * x.reg[i]; } } diff --git a/include/targets/cuda/mma_tensor_op/hmma_m16n8k8_sm80.cuh b/include/targets/cuda/mma_tensor_op/hmma_m16n8k8_sm80.cuh index b8e95a093c..2bdd83c74b 100644 --- a/include/targets/cuda/mma_tensor_op/hmma_m16n8k8_sm80.cuh +++ b/include/targets/cuda/mma_tensor_op/hmma_m16n8k8_sm80.cuh @@ -16,6 +16,11 @@ namespace quda static __device__ __host__ constexpr int inline pad_size(int m) { return m == 192 ? 0 : 8; } + static constexpr bool do_rescale() + { + return true; // true because we use FP16 + } + static constexpr int MMA_M = 16; static constexpr int MMA_N = 8; static constexpr int MMA_K = 8; diff --git a/include/targets/cuda/mma_tensor_op/hmma_tfloat32_sm80.cuh b/include/targets/cuda/mma_tensor_op/hmma_tfloat32_sm80.cuh index 5fb40699af..42ddc42c4b 100644 --- a/include/targets/cuda/mma_tensor_op/hmma_tfloat32_sm80.cuh +++ b/include/targets/cuda/mma_tensor_op/hmma_tfloat32_sm80.cuh @@ -11,6 +11,11 @@ namespace quda template struct hmma_tfloat32_t { + static constexpr bool do_rescale() + { + return false; // false because TF32 has the same range as FP32 + } + static constexpr int warp_m = warp_m_; static constexpr int warp_n = warp_n_; diff --git a/include/targets/cuda/mma_tensor_op/smma_m16n16k4_sm70.cuh b/include/targets/cuda/mma_tensor_op/smma_m16n16k4_sm70.cuh index 2d8fad7e90..c72bceae23 100644 --- a/include/targets/cuda/mma_tensor_op/smma_m16n16k4_sm70.cuh +++ b/include/targets/cuda/mma_tensor_op/smma_m16n16k4_sm70.cuh @@ -17,6 +17,11 @@ namespace quda static __device__ __host__ 
constexpr int inline pad_size(int) { return 0; } + static constexpr bool do_rescale() + { + return true; // true because we use FP16 + } + static constexpr int MMA_M = 16; static constexpr int MMA_N = 16; static constexpr int MMA_K = 4; diff --git a/include/targets/cuda/mma_tensor_op/smma_m16n8_sm80.cuh b/include/targets/cuda/mma_tensor_op/smma_m16n8_sm80.cuh index e41e089fb2..f6b2fd350f 100644 --- a/include/targets/cuda/mma_tensor_op/smma_m16n8_sm80.cuh +++ b/include/targets/cuda/mma_tensor_op/smma_m16n8_sm80.cuh @@ -91,6 +91,11 @@ namespace quda static constexpr bool use_intermediate_accumulator() { return true; }; + static constexpr bool do_rescale() + { + return std::is_same_v ? true : false; // true if we use FP16 + } + static constexpr int warp_m = warp_m_; static constexpr int warp_n = warp_n_; From 1b220b717fd785d82486d3a54a411d966f844729 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Mon, 29 Jul 2024 14:36:10 -0700 Subject: [PATCH 048/103] cvector -> vector --- lib/multigrid.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/multigrid.cpp b/lib/multigrid.cpp index 4f342b9e51..0fa224a8fd 100644 --- a/lib/multigrid.cpp +++ b/lib/multigrid.cpp @@ -794,7 +794,7 @@ namespace quda auto &tmp2 = fine_tmp[1]; auto &tmp_coarse = coarse_tmp[0]; - cvector B_norm; + vector B_norm; if (param.transfer_type == QUDA_TRANSFER_AGGREGATE) B_norm = norm2(param.B); // No need to check (projector) v_k for staggered case From 7af1adeff0f5f09d93b5f2b3818b8590de529493 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Tue, 30 Jul 2024 16:47:55 -0700 Subject: [PATCH 049/103] Fix MPI bug --- tests/utils/host_utils.cpp | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/utils/host_utils.cpp b/tests/utils/host_utils.cpp index a3be4a0d6b..c7b0808180 100644 --- a/tests/utils/host_utils.cpp +++ b/tests/utils/host_utils.cpp @@ -421,7 +421,15 @@ void initComms(int, char **, int *const commDims) 
QMP_declare_logical_topology_map(commDims, 4, map, 4); } #elif defined(MPI_COMMS) - MPI_Init_thread(&argc, &argv, MPI_THREAD_FUNNELED); + int provided = 0; + int required = MPI_THREAD_FUNNELED; + int flag = MPI_Init_thread(&argc, &argv, required, &provided); + + if (provided != required) { + printf("%s: required thread-safety level %d can't be provided %d\n", __func__, required, provided); + fflush(stdout); + exit(flag); + } #endif QudaCommsMap func = rank_order == 0 ? lex_rank_from_coords_t : lex_rank_from_coords_x; From 414260a83184a42adf25b56fb9a20ecd11cef9de Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Wed, 31 Jul 2024 03:53:25 -0700 Subject: [PATCH 050/103] Fix deflateSVD for block deflation --- lib/eigensolve_quda.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/lib/eigensolve_quda.cpp b/lib/eigensolve_quda.cpp index 6efbeaaea9..9314afafe6 100644 --- a/lib/eigensolve_quda.cpp +++ b/lib/eigensolve_quda.cpp @@ -563,9 +563,11 @@ namespace quda // 2. Perform block caxpy // A_i -> (\sigma_i)^{-1} * A_i // vec_defl = Sum_i (R_i)^{-1} * A_i - if (!accumulate) for (auto &x : sol) blas::zero(x); - for (int i = 0; i < n_defl; i++) s[i] /= evals[i].real(); + for (auto j = 0u; j < src.size(); j++) + for (int i = 0; i < n_defl; i++) { s[i * src.size() + j] /= evals[i].real(); } + // 3. Accumulate sum vec_defl = Sum_i V_i * (L_i)^{-1} * A_i + if (!accumulate) blas::zero(sol); blas::block::caxpy(s, {evecs.begin(), evecs.begin() + n_defl}, {sol.begin(), sol.end()}); } @@ -629,7 +631,6 @@ namespace quda // 3. 
Accumulate sum vec_defl = Sum_i V_i * (L_i)^{-1} * A_i if (!accumulate) blas::zero(sol); - blas::block::caxpy(s, {evecs.begin(), evecs.begin() + n_defl}, {sol.begin(), sol.end()}); } From 181a52ebcc8a8dcea146fac2364803d25f7a642e Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Tue, 27 Aug 2024 17:23:20 -0700 Subject: [PATCH 051/103] Add some sanity checking when using split grid --- lib/check_params.h | 11 +++++++++++ tests/utils/set_params.cpp | 8 +++++++- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/lib/check_params.h b/lib/check_params.h index 13cfcb7593..e692bd1f1e 100644 --- a/lib/check_params.h +++ b/lib/check_params.h @@ -450,6 +450,17 @@ void printQudaInvertParam(QudaInvertParam *param) { #else for (int d = 0; d < 4; d++) { P(split_grid[d], INVALID_INT); } /**< Grid of sub-partitions */ P(num_src_per_sub_partition, INVALID_INT); /**< Number of sources per sub-partitions */ +#ifdef CHECK_PARAM + int split_grid_size = 1; + for (int d = 0; d < 4; d++) split_grid_size *= param->split_grid[d]; + if (split_grid_size > 1) { + if (param->num_src_per_sub_partition < 1) + errorQuda("Invalid num_src_per_subpartition = %d", param->num_src_per_sub_partition); + if (param->num_src % param->num_src_per_sub_partition != 0) + errorQuda("num_src %d not compatible with num_src_per_sub_partition %d", + param->num_src, param->num_src_per_sub_partition); + } +#endif #endif #ifdef INIT_PARAM diff --git a/tests/utils/set_params.cpp b/tests/utils/set_params.cpp index a9e6fe2dc0..6c9f10c597 100644 --- a/tests/utils/set_params.cpp +++ b/tests/utils/set_params.cpp @@ -142,6 +142,9 @@ void setInvertParam(QudaInvertParam &inv_param) // Use 3D or 4D laplace inv_param.laplace3D = laplace3D; + if (Nsrc < Nsrc_tile || Nsrc % Nsrc_tile != 0) + errorQuda("Invalid combination Nsrc = %d Nsrc_tile = %d", Nsrc, Nsrc_tile); + // Some fermion specific parameters if (dslash_type == QUDA_TWISTED_MASS_DSLASH || dslash_type == QUDA_TWISTED_CLOVER_DSLASH) { inv_param.mu = mu; @@ 
-932,6 +935,9 @@ void setStaggeredInvertParam(QudaInvertParam &inv_param) inv_param.kappa = kappa = 1.0 / (8.0 + mass); // for Laplace operator inv_param.laplace3D = laplace3D; // for Laplace operator + if (Nsrc < Nsrc_tile || Nsrc % Nsrc_tile != 0) + errorQuda("Invalid combination Nsrc = %d Nsrc_tile = %d", Nsrc, Nsrc_tile); + // outer solver parameters inv_param.inv_type = inv_type; inv_param.tol = tol; @@ -943,7 +949,7 @@ void setStaggeredInvertParam(QudaInvertParam &inv_param) inv_param.solution_accumulator_pipeline = solution_accumulator_pipeline; inv_param.pipeline = pipeline; - inv_param.Ls = 1; // Nsrc + inv_param.Ls = 1; if (tol_hq == 0 && tol == 0) { errorQuda("qudaInvert: requesting zero residual\n"); From 4447320cdfb1034c656cb96b1dc4e6596f5b0dd9 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Wed, 28 Aug 2024 16:33:31 -0700 Subject: [PATCH 052/103] If communicator is not found, do not call errorQuda (which causes an infinite recursion) --- lib/communicator_stack.cpp | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/lib/communicator_stack.cpp b/lib/communicator_stack.cpp index 310184d736..1ff58cb4e8 100644 --- a/lib/communicator_stack.cpp +++ b/lib/communicator_stack.cpp @@ -27,14 +27,22 @@ namespace quda static Communicator &get_default_communicator() { auto search = communicator_stack.find(default_comm_key); - if (search == communicator_stack.end()) { errorQuda("Default communicator can't be found."); } + if (search == communicator_stack.end()) { + fprintf(getOutputFile(), "Current communicator can't be found\n"); + fflush(getOutputFile()); + comm_abort(1); + } return search->second; } Communicator &get_current_communicator() { auto search = communicator_stack.find(current_key); - if (search == communicator_stack.end()) { errorQuda("Current communicator can't be found."); } + if (search == communicator_stack.end()) { + fprintf(getOutputFile(), "Current communicator can't be found\n"); + fflush(getOutputFile()); + 
comm_abort(1); + } return search->second; } From 994bdd8875655ae7166556bb4ffbe277f387ab11 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Thu, 29 Aug 2024 13:14:02 -0700 Subject: [PATCH 053/103] Fix some verbosity aspects of tuning --- lib/tune.cpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/lib/tune.cpp b/lib/tune.cpp index cdfa24660b..3bcd7da964 100644 --- a/lib/tune.cpp +++ b/lib/tune.cpp @@ -105,7 +105,7 @@ namespace quda const std::string get_resource_path() { - static std::string resource_path; + static std::string resource_path = {}; static bool init = false; if (!init) { @@ -114,10 +114,8 @@ namespace quda if (!path) { warningQuda("Environment variable QUDA_RESOURCE_PATH is not set."); - return {}; } else if (stat(path, &pstat) || !S_ISDIR(pstat.st_mode)) { warningQuda("The path \"%s\" specified by QUDA_RESOURCE_PATH does not exist or is not a directory.", path); - return {}; } else { resource_path = path; } @@ -440,7 +438,11 @@ namespace quda auto &resource_path = get_resource_path(); if (resource_path.empty()) { - warningQuda("Caching of tuned parameters will be disabled"); + static bool init = false; + if (!init) { + warningQuda("Caching of tuned parameters will be disabled"); + init = true; + } return; } From 35e17db5d20af835dcc6800f70d14ac09ec2e6da Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Tue, 3 Sep 2024 16:20:58 -0700 Subject: [PATCH 054/103] Check set sizes match when copying between them --- include/blas_quda.h | 1 + 1 file changed, 1 insertion(+) diff --git a/include/blas_quda.h b/include/blas_quda.h index 907f23945b..77b3c404f3 100644 --- a/include/blas_quda.h +++ b/include/blas_quda.h @@ -30,6 +30,7 @@ namespace quda { inline void copy(cvector_ref &dst, cvector_ref &src) { + if (dst.size() != src.size()) errorQuda("Mismatched vector sets %lu != %lu", dst.size(), src.size()); for (auto i = 0u; i < src.size(); i++) { dst[i].copy(src[i]); } } From 661a2a1133793e7897a3352a38803fa5faf89af2 Mon Sep 17 
00:00:00 2001 From: maddyscientist Date: Tue, 3 Sep 2024 16:27:10 -0700 Subject: [PATCH 055/103] Multi-RHS solvers should check to see if their state needs to be resized --- lib/inv_ca_gcr.cpp | 2 +- lib/inv_cg_quda.cpp | 2 +- lib/inv_gcr_quda.cpp | 2 +- lib/inv_mr_quda.cpp | 2 +- lib/inv_sd_quda.cpp | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/lib/inv_ca_gcr.cpp b/lib/inv_ca_gcr.cpp index 1b5fdb372e..07fdd9fe0d 100644 --- a/lib/inv_ca_gcr.cpp +++ b/lib/inv_ca_gcr.cpp @@ -25,7 +25,7 @@ namespace quda { Solver::create(x, b); - if (!init) { + if (!init || r.size() != b.size()) { if (!param.is_preconditioner) getProfile().TPSTART(QUDA_PROFILE_INIT); alpha.resize(b.size()); diff --git a/lib/inv_cg_quda.cpp b/lib/inv_cg_quda.cpp index 70f44f3f39..c9d5f2fd27 100644 --- a/lib/inv_cg_quda.cpp +++ b/lib/inv_cg_quda.cpp @@ -31,7 +31,7 @@ namespace quda { { Solver::create(x, b); - if (!init) { + if (!init || r.size() != b.size()) { getProfile().TPSTART(QUDA_PROFILE_INIT); resize(r, b.size(), QUDA_NULL_FIELD_CREATE, b[0]); diff --git a/lib/inv_gcr_quda.cpp b/lib/inv_gcr_quda.cpp index f040e899be..d18d037d14 100644 --- a/lib/inv_gcr_quda.cpp +++ b/lib/inv_gcr_quda.cpp @@ -140,7 +140,7 @@ namespace quda { { Solver::create(x, b); - if (!init) { + if (!init || r.size() != b.size()) { getProfile().TPSTART(QUDA_PROFILE_INIT); ColorSpinorParam csParam(x[0]); csParam.create = QUDA_NULL_FIELD_CREATE; diff --git a/lib/inv_mr_quda.cpp b/lib/inv_mr_quda.cpp index aee322bd21..032950e655 100644 --- a/lib/inv_mr_quda.cpp +++ b/lib/inv_mr_quda.cpp @@ -24,7 +24,7 @@ namespace quda { Solver::create(x, b); - if (!init) { + if (!init || r.size() != b.size()) { resize(r, b.size(), QUDA_NULL_FIELD_CREATE, b[0]); // now allocate sloppy fields diff --git a/lib/inv_sd_quda.cpp b/lib/inv_sd_quda.cpp index 5a1906ba01..4af2f5f2fa 100644 --- a/lib/inv_sd_quda.cpp +++ b/lib/inv_sd_quda.cpp @@ -15,7 +15,7 @@ namespace quda { { Solver::create(x, b); - if (!init) { + if (!init || 
r.size() != b.size()) { resize(r, b.size(), QUDA_NULL_FIELD_CREATE, b[0]); resize(Ar, b.size(), QUDA_NULL_FIELD_CREATE, b[0]); init = true; From 2dd2502bee200700b83a66fa58c79dfecbc82220 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Wed, 4 Sep 2024 15:23:59 -0700 Subject: [PATCH 056/103] Add iterator-pair constructor for quda::vector class --- include/reference_wrapper_helper.h | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/include/reference_wrapper_helper.h b/include/reference_wrapper_helper.h index 50dc066701..025a623429 100644 --- a/include/reference_wrapper_helper.h +++ b/include/reference_wrapper_helper.h @@ -490,6 +490,17 @@ namespace quda vector() = default; vector(uint64_t size, const T &value = {}) : std::vector(size, value) { } + /** + Constructor from pair of iterators + @param[in] first Begin iterator + @param[in] last End iterator + */ + template > * = nullptr> vector(U first, U last) + { + std::vector::reserve(last - first); + for (auto it = first; it != last; it++) std::vector::push_back(*it); + } + /** @brief Constructor using std::vector initialization @param[in] u Vector we are copying from From 3ad7a573860e1051c41bbdf3b4637facae052d5c Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Wed, 4 Sep 2024 17:52:38 -0700 Subject: [PATCH 057/103] MRHS optimizations for eigensolver: exposed new parameter QudaEigParam::eval_block_size to set the batch size for computing the eigenvalues. 
--- include/quda.h | 2 + lib/check_params.h | 2 + lib/eigensolve_quda.cpp | 76 ++++++++++++++++++----------- tests/utils/command_line_params.cpp | 5 ++ tests/utils/command_line_params.h | 2 + tests/utils/set_params.cpp | 2 + 6 files changed, 60 insertions(+), 29 deletions(-) diff --git a/include/quda.h b/include/quda.h index 7511271f18..018f4b8470 100644 --- a/include/quda.h +++ b/include/quda.h @@ -554,6 +554,8 @@ extern "C" { int batched_rotate; /** For block method solvers, the block size **/ int block_size; + /** The block size when computing eigenvalues **/ + int eval_block_size; /** For block method solvers, quit after n attempts at block orthonormalisation **/ int max_ortho_attempts; /** For hybrid modifeld Gram-Schmidt orthonormalisations **/ diff --git a/lib/check_params.h b/lib/check_params.h index e692bd1f1e..aa17b32ab3 100644 --- a/lib/check_params.h +++ b/lib/check_params.h @@ -196,6 +196,7 @@ void printQudaEigParam(QudaEigParam *param) { P(eig_type, QUDA_EIG_TR_LANCZOS); P(extlib_type, QUDA_EIGEN_EXTLIB); P(mem_type_ritz, QUDA_MEMORY_DEVICE); + P(eval_block_size, 4); P(ortho_block_size, 0); P(partfile, QUDA_BOOLEAN_FALSE); #else @@ -226,6 +227,7 @@ void printQudaEigParam(QudaEigParam *param) { P(eig_type, QUDA_EIG_INVALID); P(extlib_type, QUDA_EXTLIB_INVALID); P(mem_type_ritz, QUDA_MEMORY_INVALID); + P(eval_block_size, INVALID_INT); P(ortho_block_size, INVALID_INT); P(partfile, QUDA_BOOLEAN_INVALID); #endif diff --git a/lib/eigensolve_quda.cpp b/lib/eigensolve_quda.cpp index 9314afafe6..62e46257c2 100644 --- a/lib/eigensolve_quda.cpp +++ b/lib/eigensolve_quda.cpp @@ -142,11 +142,11 @@ namespace quda { // Use 0th vector to extract meta data for the RNG. RNG rng(kSpace[0], 1234); + // If the spinor contains valid initial data from the user preserve it, else populate with rands. 
+ // We use `!isfinite || norm == 0` instead of `isnormal` because subnormal vectors are still numerically legal + auto norm = blas::norm2({kSpace.begin(), kSpace.begin() + block_size}); for (int b = 0; b < block_size; b++) { - // If the spinor contains valid initial data from the user preserve it, else populate with rands. - // We use `!isfinite || norm == 0` instead of `isnormal` because subnormal vectors are still numerically legal - auto norm = blas::norm2(kSpace[b]); - if (!std::isfinite(norm) || norm == 0.0) { spinorNoise(kSpace[b], rng, QUDA_NOISE_UNIFORM); } + if (!std::isfinite(norm[b]) || norm[b] == 0.0) { spinorNoise(kSpace[b], rng, QUDA_NOISE_UNIFORM); } } bool orthed = false; @@ -496,13 +496,14 @@ namespace quda { logQuda(QUDA_SUMMARIZE, "Computing SVD of M\n"); + auto block_size = eig_param->eval_block_size; int n_conv = eig_param->n_conv; if (evecs.size() < (unsigned int)(2 * n_conv)) errorQuda("Incorrect deflation space sized %d passed to computeSVD, expected %d", (int)(evecs.size()), 2 * n_conv); - std::vector sigma_tmp(n_conv); - - for (int i = 0; i < n_conv; i++) { + for (int i = 0; i < n_conv; i += block_size) { + auto lower = i; + auto upper = i + block_size < n_conv ? i + block_size : n_conv; // This function assumes that you have computed the eigenvectors // of MdagM(MMdag), ie, the right(left) SVD of M. The ith eigen vector in the @@ -515,22 +516,27 @@ namespace quda //-------------------------------------------------------------------------- // Lambda already contains the square root of the eigenvalue of the norm op. 
- Complex lambda = evals[i]; // M*Rev_i = M*Rsv_i = sigma_i Lsv_i - mat.Expose()->M(evecs[n_conv + i], evecs[i]); + mat.Expose()->M({evecs.begin() + n_conv + lower, evecs.begin() + n_conv + upper}, + {evecs.begin() + lower, evecs.begin() + upper}); // sigma_i = sqrt(sigma_i (Lsv_i)^dag * sigma_i * Lsv_i ) - sigma_tmp[i] = sqrt(blas::norm2(evecs[n_conv + i])); + auto sigma = blas::norm2({evecs.begin() + n_conv + lower, evecs.begin() + n_conv + upper}); + decltype(sigma) sigma_inv(sigma.size()); + for (auto j = 0u; j < sigma.size(); j++) { + sigma[j] = sqrt(sigma[j]); + sigma_inv[j] = 1.0 / sigma[j]; + } // Normalise the Lsv: sigma_i Lsv_i -> Lsv_i - blas::ax(1.0 / sigma_tmp[i], evecs[n_conv + i]); - - logQuda(QUDA_SUMMARIZE, "Sval[%04d] = %+.16e sigma - sqrt(|lambda|) = %+.16e\n", i, sigma_tmp[i], - sigma_tmp[i] - sqrt(abs(lambda.real()))); + blas::ax(sigma_inv, {evecs.begin() + n_conv + lower, evecs.begin() + n_conv + upper}); - evals[i] = sigma_tmp[i]; - //-------------------------------------------------------------------------- + for (auto j = 0u; j < sigma.size(); j++) { + logQuda(QUDA_SUMMARIZE, "Sval[%04d] = %+.16e sigma - sqrt(|lambda|) = %+.16e\n", i + j, sigma[j], + sigma[j] - sqrt(abs(evals[i + j].real()))); + evals[i + j] = sigma[j]; + } } } @@ -574,33 +580,45 @@ namespace quda void EigenSolver::computeEvals(std::vector &evecs, std::vector &evals, int size) { - if (size > (int)evecs.size()) + auto block_size = eig_param->eval_block_size; + + if (size > static_cast(evecs.size())) errorQuda("Requesting %d eigenvectors with only storage allocated for %lu", size, evecs.size()); + + // allocate space if needed for computing the evals + if (size + block_size > static_cast(evecs.size())) resize(evecs, size + block_size, QUDA_NULL_FIELD_CREATE); + // we make sure that we have enough space for eigenvalues // required for coarse-grid deflated solver used from within tmLQCD or PLEGMA with // `preserve_deflation` enabled if (size > (int)evals.size()) 
evals.resize(size); - ColorSpinorParam csParamClone(evecs[0]); - csParamClone.create = QUDA_NULL_FIELD_CREATE; - ColorSpinorField temp(csParamClone); + for (int i = 0; i < size; i += block_size) { + auto lower = i; + auto upper = i + block_size < size ? i + block_size : size; + + auto temp = {evecs.begin() + size, evecs.begin() + size + upper - lower}; - for (int i = 0; i < size; i++) { // r = A * v_i - mat(temp, evecs[i]); + mat(temp, {evecs.begin() + lower, evecs.begin() + upper}); // lambda_i = v_i^dag A v_i / (v_i^dag * v_i) - evals[i] = blas::cDotProduct(evecs[i], temp) / sqrt(blas::norm2(evecs[i])); + auto vtAv = blas::cDotProduct({evecs.begin() + lower, evecs.begin() + upper}, temp); + auto v2 = blas::norm2({evecs.begin() + lower, evecs.begin() + upper}); + for (auto j = 0u; j < v2.size(); j++) evals[i + j] = vtAv[j] / sqrt(v2[j]); // Measure ||lambda_i*v_i - A*v_i|| Complex n_unit(-1.0, 0.0); - blas::caxpby(evals[i], evecs[i], n_unit, temp); - residua[i] = sqrt(blas::norm2(temp)); - // eig_param->invert_param->true_res_offset[i] = residua[i]; + auto res = blas::caxpbyNorm({evals.begin() + lower, evals.begin() + upper}, + {evecs.begin() + lower, evecs.begin() + upper}, n_unit, temp); + for (auto j = 0u; j < v2.size(); j++) residua[i + j] = sqrt(res[j]); // If size = n_conv, this routine is called post sort - if (size == n_conv) - logQuda(QUDA_SUMMARIZE, "Eval[%04d] = (%+.16e,%+.16e) ||%+.16e|| Residual = %+.16e\n", i, evals[i].real(), - evals[i].imag(), abs(evals[i]), residua[i]); + if (size == n_conv) { + for (int j = lower; j < upper; j++) { + logQuda(QUDA_SUMMARIZE, "Eval[%04d] = (%+.16e,%+.16e) ||%+.16e|| Residual = %+.16e\n", j, evals[j].real(), + evals[j].imag(), abs(evals[j]), residua[j]); + } + } } } diff --git a/tests/utils/command_line_params.cpp b/tests/utils/command_line_params.cpp index fe937da3e5..3e6247e3e0 100644 --- a/tests/utils/command_line_params.cpp +++ b/tests/utils/command_line_params.cpp @@ -213,6 +213,7 @@ QudaMemoryType 
mem_type_ritz = QUDA_MEMORY_DEVICE; // Parameters for the stand alone eigensolver int eig_ortho_block_size = 0; +int eig_eval_block_size = 4; int eig_block_size = 4; int eig_n_ev = 16; int eig_n_kr = 32; @@ -250,6 +251,7 @@ bool eig_partfile = false; // all others are for PR vectors. quda::mgarray mg_eig = {}; quda::mgarray mg_eig_ortho_block_size = {}; +quda::mgarray mg_eig_eval_block_size = {}; quda::mgarray mg_eig_block_size = {}; quda::mgarray mg_eig_n_ev_deflate = {}; quda::mgarray mg_eig_n_ev = {}; @@ -775,6 +777,7 @@ void add_eigen_option_group(std::shared_ptr quda_app) opgroup->add_option("--eig-ortho-block-size", eig_ortho_block_size, "The block size to use when orthonormalising vectors in hybrid modified Gram-Schmidt" "0 for always Classical, 1 for Modified, n > 1 for Hybrid)"); + opgroup->add_option("--eig-eval-block-size", eig_eval_block_size, "The block size used when computing eigenvalues in the eigensolver"); opgroup->add_option("--eig-block-size", eig_block_size, "The block size to use in the block variant eigensolver"); opgroup->add_option( "--eig-n-ev-deflate", eig_n_ev_deflate, @@ -919,6 +922,8 @@ void add_multigrid_option_group(std::shared_ptr quda_app) "Use Eigen to eigensolve the upper Hessenberg in IRAM, else use QUDA's QR code. 
(default true)"); quda_app->add_mgoption(opgroup, "--mg-eig-ortho-block-size", mg_eig_ortho_block_size, CLI::Validator(), "The block size to use when orthonormalising vectors in hybrid modified Gram-Schmidt"); + quda_app->add_mgoption(opgroup, "--mg-eig-eval-block-size", mg_eig_eval_block_size, CLI::Validator(), + "The block size used when computing eigenvalues in the eigensolver"); quda_app->add_mgoption(opgroup, "--mg-eig-block-size", mg_eig_block_size, CLI::Validator(), "The block size to use in the block variant eigensolver"); quda_app->add_mgoption(opgroup, "--mg-eig-n-ev", mg_eig_n_ev, CLI::Validator(), diff --git a/tests/utils/command_line_params.h b/tests/utils/command_line_params.h index ecd20ae065..647f171591 100644 --- a/tests/utils/command_line_params.h +++ b/tests/utils/command_line_params.h @@ -461,6 +461,7 @@ extern QudaMemoryType mem_type_ritz; // Parameters for the stand alone eigensolver extern int eig_ortho_block_size; +extern int eig_eval_block_size; extern int eig_block_size; extern int eig_n_ev; extern int eig_n_kr; @@ -498,6 +499,7 @@ extern bool eig_partfile; // all others are for PR vectors. extern quda::mgarray mg_eig; extern quda::mgarray mg_eig_ortho_block_size; +extern quda::mgarray mg_eig_eval_block_size; extern quda::mgarray mg_eig_block_size; extern quda::mgarray mg_eig_n_ev_deflate; extern quda::mgarray mg_eig_n_ev; diff --git a/tests/utils/set_params.cpp b/tests/utils/set_params.cpp index 6c9f10c597..5b6a4eb4bd 100644 --- a/tests/utils/set_params.cpp +++ b/tests/utils/set_params.cpp @@ -335,6 +335,7 @@ void setEigParam(QudaEigParam &eig_param) } eig_param.ortho_block_size = eig_ortho_block_size; + eig_param.eval_block_size = eig_eval_block_size; eig_param.block_size = (eig_param.eig_type == QUDA_EIG_TR_LANCZOS || eig_param.eig_type == QUDA_EIG_IR_ARNOLDI) ? 
1 : eig_block_size; eig_param.n_ev = eig_n_ev; @@ -786,6 +787,7 @@ void setMultigridEigParam(QudaEigParam &mg_eig_param, int level) } mg_eig_param.ortho_block_size = mg_eig_ortho_block_size[level]; + mg_eig_param.eval_block_size = mg_eig_eval_block_size[level]; mg_eig_param.block_size = (mg_eig_param.eig_type == QUDA_EIG_TR_LANCZOS || mg_eig_param.eig_type == QUDA_EIG_IR_ARNOLDI) ? 1 : From a386654c45b27966af84eb8a9e7b195b2a83084b Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Wed, 4 Sep 2024 17:54:31 -0700 Subject: [PATCH 058/103] Preserve eigen space when running multi-src deflated solves --- tests/invert_test.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/invert_test.cpp b/tests/invert_test.cpp index 9bc95b381b..92b68e71e9 100644 --- a/tests/invert_test.cpp +++ b/tests/invert_test.cpp @@ -347,6 +347,9 @@ std::vector> solve(test_t param) std::vector _hp_b(Nsrc_tile); for (int j = 0; j < Nsrc; j += Nsrc_tile) { + // If deflating, preserve the deflation space between solves + if (inv_deflate) eig_param.preserve_deflation = j < Nsrc - Nsrc_tile ? 
QUDA_BOOLEAN_TRUE : QUDA_BOOLEAN_FALSE; + for (int i = 0; i < Nsrc_tile; i++) { _hp_x[i] = out[j + i].data(); _hp_b[i] = in[j + i].data(); From 0ed1fe9ee31b5b65799013b9fb13f7124ce4dc98 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Thu, 5 Sep 2024 17:02:00 -0700 Subject: [PATCH 059/103] Fix CI warnings (one of which was a real bug) --- include/eigen_helper.h | 8 ++++++++ include/invert_quda.h | 4 ++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/include/eigen_helper.h b/include/eigen_helper.h index b3eb57baf0..da4d054449 100644 --- a/include/eigen_helper.h +++ b/include/eigen_helper.h @@ -11,14 +11,22 @@ #include +#define GCC_COMPILER (defined(__GNUC__) && !defined(__clang__)) + // hide annoying warning +#ifdef GCC_COMPILER #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wmaybe-uninitialized" +#endif #include #include #include +#ifdef GCC_COMPILER #pragma GCC diagnostic pop +#endif + +#undef GCC_COMPILER using namespace Eigen; diff --git a/include/invert_quda.h b/include/invert_quda.h index 62d9019523..5b9e5052c5 100644 --- a/include/invert_quda.h +++ b/include/invert_quda.h @@ -894,7 +894,7 @@ namespace quda { public: CG3(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, SolverParam &param); - void operator()(cvector_ref &out, cvector_ref &in) override + virtual void operator()(cvector_ref &out, cvector_ref &in) override { for (auto i = 0u; i < in.size(); i++) operator()(out[i], in[i]); } @@ -969,7 +969,7 @@ namespace quda { public: CG3NR(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, SolverParam &param); - void operator()(cvector &out, cvector_ref &in) + void operator()(cvector_ref &out, cvector_ref &in) override { for (auto i = 0u; i < in.size(); i++) operator()(out[i], in[i]); } From 75a3b4c17cdf4e7e430c5989c090eaeabb8266bb Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Thu, 5 Sep 2024 22:13:05 -0700 Subject: [PATCH 060/103] More CI warnings ---
include/accelerator.h | 2 +- include/eigen_helper.h | 8 ++------ include/invert_quda.h | 2 +- lib/multigrid.cpp | 2 -- 4 files changed, 4 insertions(+), 10 deletions(-) diff --git a/include/accelerator.h b/include/accelerator.h index 53b5a811c7..a1113c57be 100644 --- a/include/accelerator.h +++ b/include/accelerator.h @@ -69,7 +69,7 @@ namespace quda * @param null Solver to solve for null vectors. * @param in meta color spinor field. */ - virtual void train_param(Solver &null, const ColorSpinorField &in) + virtual void train_param(Solver &null, const ColorSpinorField &in) override { if (!active_training && !transformer.trained) { active_training = true; diff --git a/include/eigen_helper.h b/include/eigen_helper.h index da4d054449..8c1bf012a8 100644 --- a/include/eigen_helper.h +++ b/include/eigen_helper.h @@ -11,10 +11,8 @@ #include -#define GCC_COMPILER (defined(__GNUC__) && !defined(__clang__)) - // hide annoying warning -#ifdef GCC_COMPILER +#if !defined(__clang__) && !defined(_NVHPC_CUDA) #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wmaybe-uninitialized" #endif @@ -23,10 +21,8 @@ #include #include -#ifdef GCC_COMPILER +#if !defined(__clang__) && !defined(_NVHPC_CUDA) #pragma GCC diagnostic pop #endif -#undef GCC_COMPILER - using namespace Eigen; diff --git a/include/invert_quda.h b/include/invert_quda.h index 5b9e5052c5..ba29ca92c3 100644 --- a/include/invert_quda.h +++ b/include/invert_quda.h @@ -979,7 +979,7 @@ namespace quda { /** @return Return the residual vector from the prior solve */ - cvector_ref get_residual(); + cvector_ref get_residual() override; virtual bool hermitian() const final { return false; } /** CG3NR is for any system */ diff --git a/lib/multigrid.cpp b/lib/multigrid.cpp index 0fa224a8fd..e36900f155 100644 --- a/lib/multigrid.cpp +++ b/lib/multigrid.cpp @@ -13,8 +13,6 @@ namespace quda using namespace blas; - static constexpr bool debug = false; - MG::MG(MGParam &param) : Solver(*param.matResidual, *param.matSmooth,
*param.matSmoothSloppy, *param.matSmoothSloppy, param), param(param), From 37a6ae93bc97fa6787090bbe0c461aee3caf0116 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Fri, 6 Sep 2024 12:05:54 -0700 Subject: [PATCH 061/103] Fix some outstanding CI issues --- tests/utils/misc.cpp | 6 +++--- tests/utils/set_params.cpp | 2 ++ 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/utils/misc.cpp b/tests/utils/misc.cpp index 5f7b3e6ddd..8ec29c514b 100644 --- a/tests/utils/misc.cpp +++ b/tests/utils/misc.cpp @@ -133,10 +133,10 @@ std::vector get_dslash_str_list() dslash_str_list.push_back("staggered"); dslash_str_list.push_back("asqtad"); dslash_str_list.push_back("hisq"); - dslash_str_list.push_back("domain_wall"); - dslash_str_list.push_back("domain_wall_4d"); + dslash_str_list.push_back("domain-wall"); + dslash_str_list.push_back("domain-wall-4d"); dslash_str_list.push_back("mobius"); - dslash_str_list.push_back("mobius_eofa"); + dslash_str_list.push_back("mobius-eofa"); dslash_str_list.push_back("laplace"); populated = true; } diff --git a/tests/utils/set_params.cpp b/tests/utils/set_params.cpp index 5b6a4eb4bd..e46b40e804 100644 --- a/tests/utils/set_params.cpp +++ b/tests/utils/set_params.cpp @@ -950,6 +950,8 @@ void setStaggeredInvertParam(QudaInvertParam &inv_param) inv_param.use_sloppy_partial_accumulator = false; inv_param.solution_accumulator_pipeline = solution_accumulator_pipeline; inv_param.pipeline = pipeline; + inv_param.max_res_increase = max_res_increase; + inv_param.max_res_increase_total = max_res_increase_total; inv_param.Ls = 1; From fc0762d6152ff9b3f61b4c5d1787273c3d64ca02 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Fri, 6 Sep 2024 13:25:56 -0700 Subject: [PATCH 062/103] renaming as suggested in CI --- include/quda.h | 4 ++-- lib/check_params.h | 4 ++-- lib/eigensolve_quda.cpp | 14 +++++++------- tests/utils/command_line_params.cpp | 8 ++++---- tests/utils/command_line_params.h | 4 ++-- tests/utils/set_params.cpp | 4 ++-- 6 
files changed, 19 insertions(+), 19 deletions(-) diff --git a/include/quda.h b/include/quda.h index 018f4b8470..3481818cc4 100644 --- a/include/quda.h +++ b/include/quda.h @@ -554,8 +554,8 @@ extern "C" { int batched_rotate; /** For block method solvers, the block size **/ int block_size; - /** The block size when computing eigenvalues **/ - int eval_block_size; + /** The batch size used when computing eigenvalues **/ + int compute_evals_batch_size; /** For block method solvers, quit after n attempts at block orthonormalisation **/ int max_ortho_attempts; /** For hybrid modifeld Gram-Schmidt orthonormalisations **/ diff --git a/lib/check_params.h b/lib/check_params.h index aa17b32ab3..b1fe81fc28 100644 --- a/lib/check_params.h +++ b/lib/check_params.h @@ -196,7 +196,7 @@ void printQudaEigParam(QudaEigParam *param) { P(eig_type, QUDA_EIG_TR_LANCZOS); P(extlib_type, QUDA_EIGEN_EXTLIB); P(mem_type_ritz, QUDA_MEMORY_DEVICE); - P(eval_block_size, 4); + P(compute_evals_batch_size, 4); P(ortho_block_size, 0); P(partfile, QUDA_BOOLEAN_FALSE); #else @@ -227,7 +227,7 @@ void printQudaEigParam(QudaEigParam *param) { P(eig_type, QUDA_EIG_INVALID); P(extlib_type, QUDA_EXTLIB_INVALID); P(mem_type_ritz, QUDA_MEMORY_INVALID); - P(eval_block_size, INVALID_INT); + P(compute_evals_batch_size, INVALID_INT); P(ortho_block_size, INVALID_INT); P(partfile, QUDA_BOOLEAN_INVALID); #endif diff --git a/lib/eigensolve_quda.cpp b/lib/eigensolve_quda.cpp index 62e46257c2..8130897cbb 100644 --- a/lib/eigensolve_quda.cpp +++ b/lib/eigensolve_quda.cpp @@ -496,14 +496,14 @@ namespace quda { logQuda(QUDA_SUMMARIZE, "Computing SVD of M\n"); - auto block_size = eig_param->eval_block_size; + auto batch_size = eig_param->compute_evals_batch_size; int n_conv = eig_param->n_conv; if (evecs.size() < (unsigned int)(2 * n_conv)) errorQuda("Incorrect deflation space sized %d passed to computeSVD, expected %d", (int)(evecs.size()), 2 * n_conv); - for (int i = 0; i < n_conv; i += block_size) { + for (int i = 0; 
i < n_conv; i += batch_size) { auto lower = i; - auto upper = i + block_size < n_conv ? i + block_size : n_conv; + auto upper = i + batch_size < n_conv ? i + batch_size : n_conv; // This function assumes that you have computed the eigenvectors // of MdagM(MMdag), ie, the right(left) SVD of M. The ith eigen vector in the @@ -580,22 +580,22 @@ namespace quda void EigenSolver::computeEvals(std::vector &evecs, std::vector &evals, int size) { - auto block_size = eig_param->eval_block_size; + auto batch_size = eig_param->compute_evals_batch_size; if (size > static_cast(evecs.size())) errorQuda("Requesting %d eigenvectors with only storage allocated for %lu", size, evecs.size()); // allocate space if needed for computing the evals - if (size + block_size > static_cast(evecs.size())) resize(evecs, size + block_size, QUDA_NULL_FIELD_CREATE); + if (size + batch_size > static_cast(evecs.size())) resize(evecs, size + batch_size, QUDA_NULL_FIELD_CREATE); // we make sure that we have enough space for eigenvalues // required for coarse-grid deflated solver used from within tmLQCD or PLEGMA with // `preserve_deflation` enabled if (size > (int)evals.size()) evals.resize(size); - for (int i = 0; i < size; i += block_size) { + for (int i = 0; i < size; i += batch_size) { auto lower = i; - auto upper = i + block_size < size ? i + block_size : size; + auto upper = i + batch_size < size ? 
i + batch_size : size; auto temp = {evecs.begin() + size, evecs.begin() + size + upper - lower}; diff --git a/tests/utils/command_line_params.cpp b/tests/utils/command_line_params.cpp index 84d5bf46e3..3c172ebee0 100644 --- a/tests/utils/command_line_params.cpp +++ b/tests/utils/command_line_params.cpp @@ -213,7 +213,7 @@ QudaMemoryType mem_type_ritz = QUDA_MEMORY_DEVICE; // Parameters for the stand alone eigensolver int eig_ortho_block_size = 0; -int eig_eval_block_size = 4; +int eig_evals_batch_size = 4; int eig_block_size = 4; int eig_n_ev = 16; int eig_n_kr = 32; @@ -251,7 +251,7 @@ bool eig_partfile = false; // all others are for PR vectors. quda::mgarray mg_eig = {}; quda::mgarray mg_eig_ortho_block_size = {}; -quda::mgarray mg_eig_eval_block_size = {}; +quda::mgarray mg_eig_evals_batch_size = {}; quda::mgarray mg_eig_block_size = {}; quda::mgarray mg_eig_n_ev_deflate = {}; quda::mgarray mg_eig_n_ev = {}; @@ -791,7 +791,7 @@ void add_eigen_option_group(std::shared_ptr quda_app) opgroup->add_option("--eig-ortho-block-size", eig_ortho_block_size, "The block size to use when orthonormalising vectors in hybrid modified Gram-Schmidt" "0 for always Classical, 1 for Modified, n > 1 for Hybrid)"); - opgroup->add_option("--eig-eval-block-size", eig_eval_block_size, "The block size used when computing eigenvalues in the eigensolver"); + opgroup->add_option("--eig-evals-batch-size", eig_evals_batch_size, "The batch size used when computing eigenvalues in the eigensolver"); opgroup->add_option("--eig-block-size", eig_block_size, "The block size to use in the block variant eigensolver"); opgroup->add_option( "--eig-n-ev-deflate", eig_n_ev_deflate, @@ -936,7 +936,7 @@ void add_multigrid_option_group(std::shared_ptr quda_app) "Use Eigen to eigensolve the upper Hessenberg in IRAM, else use QUDA's QR code. 
(default true)"); quda_app->add_mgoption(opgroup, "--mg-eig-ortho-block-size", mg_eig_ortho_block_size, CLI::Validator(), "The block size to use when orthonormalising vectors in hybrid modified Gram-Schmidt"); - quda_app->add_mgoption(opgroup, "--mg-eig-eval-block-size", mg_eig_eval_block_size, CLI::Validator(), + quda_app->add_mgoption(opgroup, "--mg-eig-evals-batch-size", mg_eig_evals_batch_size, CLI::Validator(), "The block size used when computing eigenvalues in the eigensolver"); quda_app->add_mgoption(opgroup, "--mg-eig-block-size", mg_eig_block_size, CLI::Validator(), "The block size to use in the block variant eigensolver"); diff --git a/tests/utils/command_line_params.h b/tests/utils/command_line_params.h index 34a6441b45..338dcc73bc 100644 --- a/tests/utils/command_line_params.h +++ b/tests/utils/command_line_params.h @@ -462,7 +462,7 @@ extern QudaMemoryType mem_type_ritz; // Parameters for the stand alone eigensolver extern int eig_ortho_block_size; -extern int eig_eval_block_size; +extern int eig_evals_batch_size; extern int eig_block_size; extern int eig_n_ev; extern int eig_n_kr; @@ -500,7 +500,7 @@ extern bool eig_partfile; // all others are for PR vectors. extern quda::mgarray mg_eig; extern quda::mgarray mg_eig_ortho_block_size; -extern quda::mgarray mg_eig_eval_block_size; +extern quda::mgarray mg_eig_evals_batch_size; extern quda::mgarray mg_eig_block_size; extern quda::mgarray mg_eig_n_ev_deflate; extern quda::mgarray mg_eig_n_ev; diff --git a/tests/utils/set_params.cpp b/tests/utils/set_params.cpp index e46b40e804..bc124c6c31 100644 --- a/tests/utils/set_params.cpp +++ b/tests/utils/set_params.cpp @@ -335,7 +335,7 @@ void setEigParam(QudaEigParam &eig_param) } eig_param.ortho_block_size = eig_ortho_block_size; - eig_param.eval_block_size = eig_eval_block_size; + eig_param.compute_evals_batch_size = eig_evals_batch_size; eig_param.block_size = (eig_param.eig_type == QUDA_EIG_TR_LANCZOS || eig_param.eig_type == QUDA_EIG_IR_ARNOLDI) ? 
1 : eig_block_size; eig_param.n_ev = eig_n_ev; @@ -787,7 +787,7 @@ void setMultigridEigParam(QudaEigParam &mg_eig_param, int level) } mg_eig_param.ortho_block_size = mg_eig_ortho_block_size[level]; - mg_eig_param.eval_block_size = mg_eig_eval_block_size[level]; + mg_eig_param.compute_evals_batch_size = mg_eig_evals_batch_size[level]; mg_eig_param.block_size = (mg_eig_param.eig_type == QUDA_EIG_TR_LANCZOS || mg_eig_param.eig_type == QUDA_EIG_IR_ARNOLDI) ? 1 : From 44d8a2a1b17bda0bc4fac1b071d50ac225359d3c Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Fri, 6 Sep 2024 13:28:08 -0700 Subject: [PATCH 063/103] Use std::vector iterator constructor --- include/reference_wrapper_helper.h | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/include/reference_wrapper_helper.h b/include/reference_wrapper_helper.h index 025a623429..1dbf9f6df8 100644 --- a/include/reference_wrapper_helper.h +++ b/include/reference_wrapper_helper.h @@ -246,11 +246,8 @@ namespace quda @param[in] first Begin iterator @param[in] last End iterator */ - template >* = nullptr> - vector_ref(U first, U last) + template > * = nullptr> vector_ref(U first, U last) : vector(first, last) { - vector::reserve(last - first); - for (auto it = first; it != last; it++) vector::push_back(*it); } /** @@ -495,10 +492,9 @@ namespace quda @param[in] first Begin iterator @param[in] last End iterator */ - template > * = nullptr> vector(U first, U last) + template > * = nullptr> + vector(U first, U last) : std::vector(first, last) { - std::vector::reserve(last - first); - for (auto it = first; it != last; it++) std::vector::push_back(*it); } /** From e89be7d340ad72d37cec7a83c4d1fc25d8f85a75 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Fri, 6 Sep 2024 13:28:38 -0700 Subject: [PATCH 064/103] Revert change made in this branch --- lib/dirac_coarse.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/dirac_coarse.cpp b/lib/dirac_coarse.cpp index 7bf1225f7c..62f4f7639e 
100644 --- a/lib/dirac_coarse.cpp +++ b/lib/dirac_coarse.cpp @@ -616,7 +616,7 @@ namespace quda { checkFullSpinor(x, b); auto tmp = getFieldTmp(x.Even()); -#if 1 +#if 0 // x_o = A_oo^-1 (b_o - D_oe x_e) DiracCoarse::Dslash(tmp, x(this_parity), other_parity); blas::xpay(b(other_parity), -1.0, tmp); From dcd0d439a99c31c89a6d2ec8370129df614be2af Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Fri, 6 Sep 2024 13:33:54 -0700 Subject: [PATCH 065/103] Cleanup of DiracCloverHasenbuschTwistPC --- lib/dirac_clover_hasenbusch_twist.cpp | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/lib/dirac_clover_hasenbusch_twist.cpp b/lib/dirac_clover_hasenbusch_twist.cpp index 079f2ec5b9..1e7c614526 100644 --- a/lib/dirac_clover_hasenbusch_twist.cpp +++ b/lib/dirac_clover_hasenbusch_twist.cpp @@ -115,10 +115,6 @@ namespace quda double kappa2 = -kappa * kappa; auto tmp = getFieldTmp(out); - bool symmetric = (matpcType == QUDA_MATPC_EVEN_EVEN || matpcType == QUDA_MATPC_ODD_ODD) ? true : false; - int odd_bit = (matpcType == QUDA_MATPC_ODD_ODD || matpcType == QUDA_MATPC_ODD_ODD_ASYMMETRIC) ? 1 : 0; - QudaParity parity[2] = {static_cast((1 + odd_bit) % 2), static_cast((0 + odd_bit) % 2)}; - if (!symmetric) { // No need to change order of calls for dagger // because the asymmetric operator is actually symmetric @@ -126,27 +122,27 @@ namespace quda // the pieces in Dslash and DslashXPay respect the dagger // DiracCloverHasenbuschTwistPC::Dslash applies A^{-1}Dslash - Dslash(tmp, in, parity[0]); + Dslash(tmp, in, other_parity); // applies (A + imu*g5 - kappa^2 D)- - ApplyTwistedClover(out, tmp, *gauge, *clover, kappa2, mu, in, parity[1], dagger, commDim.data, profile); + ApplyTwistedClover(out, tmp, *gauge, *clover, kappa2, mu, in, this_parity, dagger, commDim.data, profile); } else if (!dagger) { // symmetric preconditioning // We need two cases because M = 1-ADAD and M^\dag = 1-D^\dag A D^dag A // where A is actually a clover inverse. 
// This is the non-dag case: AD - Dslash(tmp, in, parity[0]); + Dslash(tmp, in, other_parity); // Then x + AD (AD) - DslashXpayTwistClovInv(out, tmp, parity[1], in, kappa2, mu); + DslashXpayTwistClovInv(out, tmp, this_parity, in, kappa2, mu); } else { // symmetric preconditioning, dagger // This is the dagger: 1 - DADA // i) Apply A - CloverInv(out, in, parity[1]); + CloverInv(out, in, this_parity); // ii) Apply A D => ADA - Dslash(tmp, out, parity[0]); + Dslash(tmp, out, other_parity); // iii) Apply x + D(ADA) - DslashXpayTwistNoClovInv(out, tmp, parity[1], in, kappa2, mu); + DslashXpayTwistNoClovInv(out, tmp, this_parity, in, kappa2, mu); } } From f0d9f3c8dc79bdb4ed751b34193262c517a6b287 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Fri, 6 Sep 2024 16:25:16 -0700 Subject: [PATCH 066/103] Cleanup zero source checking in the solvers --- include/invert_quda.h | 15 ++++++++++++++- lib/inv_ca_gcr.cpp | 19 +++---------------- lib/inv_cg_quda.cpp | 32 ++++++-------------------------- lib/inv_gcr_quda.cpp | 31 +++++++++++-------------------- lib/inv_sd_quda.cpp | 15 +-------------- lib/solve.cpp | 1 + lib/solver.cpp | 20 ++++++++++++++++++++ tests/staggered_invert_test.cpp | 1 + 8 files changed, 57 insertions(+), 77 deletions(-) diff --git a/include/invert_quda.h b/include/invert_quda.h index ba29ca92c3..fbe3234e7c 100644 --- a/include/invert_quda.h +++ b/include/invert_quda.h @@ -424,6 +424,19 @@ namespace quda { bool mixed() { return param.precision != param.precision_sloppy; } + /** + @brief Check the support of each source field, and return true + if all fields in the set have zero support. If we are doing + null-space finding, this function always returns false. If a + given source vector does have zero support, then we set the + matching solution vector to match.
+ @param[in] x Solution vector set + @param[in] b Source vector set + @param[in] b2 Vector of norms + @return boolean if all vectors have zero support + */ + bool is_zero_src(cvector_ref &x, cvector_ref &b, cvector &b2); + public: Solver(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, const DiracMatrix &matEig, SolverParam &param); @@ -551,7 +564,7 @@ namespace quda { } /** - @briefTest for solver convergence + @brief Test for solver convergence @param[in] r2 L2 norm squared of the residual @param[in] hq2 Heavy quark residual @param[in] r2_tol Solver L2 tolerance diff --git a/lib/inv_ca_gcr.cpp b/lib/inv_ca_gcr.cpp index 07fdd9fe0d..53d71bdaee 100644 --- a/lib/inv_ca_gcr.cpp +++ b/lib/inv_ca_gcr.cpp @@ -237,22 +237,9 @@ namespace quda double b_map = -(lambda_max + lambda_min) / (lambda_max - lambda_min); // Check to see that we're not trying to invert on a zero-field source - if (param.compute_null_vector == QUDA_COMPUTE_NULL_VECTOR_NO) { - bool zero_src = true; - for (auto i = 0u; i < b.size(); i++) { - if (b2[i] == 0) { - warningQuda("inverting on zero-field source"); - x[i] = b[i]; - param.true_res[i] = 0.0; - param.true_res_hq[i] = 0.0; - } else { - zero_src = false; - } - } - if (zero_src) { - getProfile().TPSTOP(QUDA_PROFILE_INIT); - return; - } + if (is_zero_src(x, b, b2)) { + getProfile().TPSTOP(QUDA_PROFILE_INIT); + return; } auto stop = !fixed_iteration ?
stopping(param.tol, b2, param.residual_type) : diff --git a/lib/inv_cg_quda.cpp b/lib/inv_cg_quda.cpp index c9d5f2fd27..18ebd4c851 100644 --- a/lib/inv_cg_quda.cpp +++ b/lib/inv_cg_quda.cpp @@ -255,19 +255,9 @@ namespace quda { vector b2 = blas::norm2(b); // Check to see that we're not trying to invert on a zero-field source - if (param.compute_null_vector == QUDA_COMPUTE_NULL_VECTOR_NO) { - bool zero_src = true; - for (auto i = 0u; i < b.size(); i++) { - if (b2[i] == 0) { - warningQuda("inverting on zero-field source"); - x[i] = b[i]; - param.true_res[i] = 0.0; - param.true_res_hq[i] = 0.0; - } else { - zero_src = false; - } - } - if (zero_src) return; + if (is_zero_src(x, b, b2)) { + getProfile().TPSTOP(QUDA_PROFILE_INIT); + return; } create(x, b); @@ -622,19 +612,9 @@ namespace quda { bool heavy_quark_restart = false; // Check to see that we're not trying to invert on a zero-field source - if (param.compute_null_vector == QUDA_COMPUTE_NULL_VECTOR_NO) { - bool zero_src = true; - for (auto i = 0u; i < b.size(); i++) { - if (b2[i] == 0) { - warningQuda("inverting on zero-field source"); - x[i] = b[i]; - param.true_res[i] = 0.0; - param.true_res_hq[i] = 0.0; - } else { - zero_src = false; - } - } - if (zero_src) return; + if (is_zero_src(x, b, b2)) { + getProfile().TPSTOP(QUDA_PROFILE_INIT); + return; } create(x, b); diff --git a/lib/inv_gcr_quda.cpp b/lib/inv_gcr_quda.cpp index d18d037d14..d7df9eb4fc 100644 --- a/lib/inv_gcr_quda.cpp +++ b/lib/inv_gcr_quda.cpp @@ -255,22 +255,9 @@ namespace quda { } // Check to see that we're not trying to invert on a zero-field source - if (param.compute_null_vector == QUDA_COMPUTE_NULL_VECTOR_NO) { - bool zero_src = true; - for (auto i = 0u; i < b.size(); i++) { - if (b2[i] == 0) { - warningQuda("inverting on zero-field source"); - x[i] = b[i]; - param.true_res[i] = 0.0; - param.true_res_hq[i] = 0.0; - } else { - zero_src = false; - } - } - if (zero_src) { - getProfile().TPSTOP(QUDA_PROFILE_INIT); - return; - } + if 
(is_zero_src(x, b, b2)) { + getProfile().TPSTOP(QUDA_PROFILE_INIT); + return; } auto stop = stopping(param.tol, b2, param.residual_type); // stopping condition of solver @@ -285,8 +272,10 @@ namespace quda { const int maxResIncreaseTotal = param.max_res_increase_total; std::vector heavy_quark_res(b.size()); // heavy quark residual - if (use_heavy_quark_res) - for (auto i = 0u; i < b.size(); i++) heavy_quark_res[i] = sqrt(blas::HeavyQuarkResidualNorm(x, r)[i].z); + if (use_heavy_quark_res) { + auto hq = blas::HeavyQuarkResidualNorm(x, r); + for (auto i = 0u; i < b.size(); i++) heavy_quark_res[i] = sqrt(hq[i].z); + } int resIncrease = 0; int resIncreaseTotal = 0; @@ -376,8 +365,10 @@ namespace quda { maxr_deflate = sqrt(r2[0]); } - if (use_heavy_quark_res) - for (auto i = 0u; i < b.size(); i++) heavy_quark_res[i] = sqrt(blas::HeavyQuarkResidualNorm(x, r)[i].z); + if (use_heavy_quark_res) { + auto hq = blas::HeavyQuarkResidualNorm(x, r); + for (auto i = 0u; i < b.size(); i++) heavy_quark_res[i] = sqrt(hq[i].z); + } // break-out check if we have reached the limit of the precision if (r2 > r2_old) { diff --git a/lib/inv_sd_quda.cpp b/lib/inv_sd_quda.cpp index 4af2f5f2fa..929064ee28 100644 --- a/lib/inv_sd_quda.cpp +++ b/lib/inv_sd_quda.cpp @@ -38,20 +38,7 @@ namespace quda { vector r2; // Check to see that we're not trying to invert on a zero-field source - if (param.compute_null_vector == QUDA_COMPUTE_NULL_VECTOR_NO) { - bool zero_src = true; - for (auto i = 0u; i < b.size(); i++) { - if (b2[i] == 0) { - warningQuda("inverting on zero-field source"); - x[i] = b[i]; - param.true_res[i] = 0.0; - param.true_res_hq[i] = 0.0; - } else { - zero_src = false; - } - } - if (zero_src) return; - } + if (is_zero_src(x, b, b2)) return; if (param.use_init_guess == QUDA_USE_INIT_GUESS_YES) { // Compute the true residual diff --git a/lib/solve.cpp b/lib/solve.cpp index 730152f60d..938e3cfecd 100644 --- a/lib/solve.cpp +++ b/lib/solve.cpp @@ -407,4 +407,5 @@ namespace quda 
popVerbosity(); } + } // namespace quda diff --git a/lib/solver.cpp b/lib/solver.cpp index b4f84da600..f32cf8a9fd 100644 --- a/lib/solver.cpp +++ b/lib/solver.cpp @@ -544,4 +544,24 @@ namespace quda { } } + // check we're not solving on a zero-valued source + bool Solver::is_zero_src(cvector_ref &x, cvector_ref &b, cvector &b2) + { + // if computing null vectors then zero sources are fine + if (param.compute_null_vector != QUDA_COMPUTE_NULL_VECTOR_NO) return false; + + bool zero_src = true; + for (auto i = 0u; i < b.size(); i++) { + if (b2[i] == 0) { + warningQuda("source %d is zero", i); + x[i] = b[i]; + param.true_res[i] = 0.0; + param.true_res_hq[i] = 0.0; + } else { + zero_src = false; + } + } + return zero_src; + } + } // namespace quda diff --git a/tests/staggered_invert_test.cpp b/tests/staggered_invert_test.cpp index 460fe7de48..d512c5463e 100644 --- a/tests/staggered_invert_test.cpp +++ b/tests/staggered_invert_test.cpp @@ -49,6 +49,7 @@ void display_test_info() printfQuda(" - number of levels %d\n", mg_levels); for (int i = 0; i < mg_levels - 1; i++) { printfQuda(" - level %d number of null-space vectors %d\n", i + 1, nvec[i]); + printfQuda(" - level %d null-space vector batch size %d\n", i + 1, nvec_batch[i]); printfQuda(" - level %d number of pre-smoother applications %d\n", i + 1, nu_pre[i]); printfQuda(" - level %d number of post-smoother applications %d\n", i + 1, nu_post[i]); } From 36c9b6db7b8edd0ec7cabd990e5bce6604f37644 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Mon, 9 Sep 2024 22:48:52 -0700 Subject: [PATCH 067/103] CGNR and CGNE are now MRHS --- include/invert_quda.h | 24 ++++-------- lib/inv_cg_quda.cpp | 90 +++++++++++++++++++++++++------------------ 2 files changed, 60 insertions(+), 54 deletions(-) diff --git a/include/invert_quda.h b/include/invert_quda.h index fbe3234e7c..9302d968ac 100644 --- a/include/invert_quda.h +++ b/include/invert_quda.h @@ -812,8 +812,8 @@ namespace quda { DiracMMdag mmdagSloppy; DiracMMdag mmdagPrecon; 
DiracMMdag mmdagEig; - ColorSpinorField xp; - ColorSpinorField yp; + std::vector xe; + std::vector ye; bool init = false; /** @@ -821,18 +821,13 @@ namespace quda { @param[in] x Solution vector @param[in] b Source vector */ - void create(ColorSpinorField &x, const ColorSpinorField &b); + void create(cvector_ref &x, cvector_ref &b); public: CGNE(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, const DiracMatrix &matEig, SolverParam ¶m); - void operator()(cvector_ref &out, cvector_ref &in) override - { - for (auto i = 0u; i < in.size(); i++) operator()(out[i], in[i]); - } - - void operator()(ColorSpinorField &out, const ColorSpinorField &in); + void operator()(cvector_ref &out, cvector_ref &in) override; /** @return Return the residual vector from the prior solve @@ -852,7 +847,7 @@ namespace quda { DiracMdagM mdagmSloppy; DiracMdagM mdagmPrecon; DiracMdagM mdagmEig; - ColorSpinorField br; + std::vector br; bool init = false; /** @@ -860,18 +855,13 @@ namespace quda { @param[in] x Solution vector @param[in] b Source vector */ - void create(ColorSpinorField &x, const ColorSpinorField &b); + void create(cvector_ref &x, cvector_ref &b); public: CGNR(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, const DiracMatrix &matEig, SolverParam ¶m); - void operator()(cvector_ref &out, cvector_ref &in) override - { - for (auto i = 0u; i < in.size(); i++) operator()(out[i], in[i]); - } - - void operator()(ColorSpinorField &out, const ColorSpinorField &in); + void operator()(cvector_ref &out, cvector_ref &in) override; /** @return Return the residual vector from the prior solve diff --git a/lib/inv_cg_quda.cpp b/lib/inv_cg_quda.cpp index 18ebd4c851..b6c7c6a593 100644 --- a/lib/inv_cg_quda.cpp +++ b/lib/inv_cg_quda.cpp @@ -70,15 +70,15 @@ namespace quda { { } - void CGNE::create(ColorSpinorField &x, const ColorSpinorField &b) + void CGNE::create(cvector_ref &x, cvector_ref &b) { Solver::create(x, b); if (!init) { 
- ColorSpinorParam csParam(x); + ColorSpinorParam csParam(x[0]); csParam.create = QUDA_NULL_FIELD_CREATE; - xp = ColorSpinorField(csParam); + resize(xe, b.size(), csParam); csParam.create = QUDA_ZERO_FIELD_CREATE; - yp = ColorSpinorField(csParam); + resize(ye, b.size(), csParam); init = true; } } @@ -88,11 +88,11 @@ namespace quda { if (!init) errorQuda("No residual vector present"); if (!param.return_residual) errorQuda("SolverParam::return_residual not enabled"); // CG residual will match the CGNE residual (FIXME: but only with zero initial guess?) - return param.use_init_guess ? xp : CG::get_residual(); + return param.use_init_guess ? xe : CG::get_residual(); } // CGNE: M Mdag y = b is solved; x = Mdag y is returned as solution. - void CGNE::operator()(ColorSpinorField &x, const ColorSpinorField &b) + void CGNE::operator()(cvector_ref &x, cvector_ref &b) { if (param.maxiter == 0 || param.Nsteps == 0) { if (param.use_init_guess == QUDA_USE_INIT_GUESS_NO) blas::zero(x); @@ -102,43 +102,52 @@ namespace quda { create(x, b); const int iter0 = param.iter; - double b2 = param.compute_true_res ? blas::norm2(b) : 0.0; + auto b2 = param.compute_true_res ? 
blas::norm2(b) : vector(b.size(), 0.0); if (param.use_init_guess == QUDA_USE_INIT_GUESS_YES) { // compute initial residual - mmdag.Expose()->M(xp, x); - if (param.compute_true_res && b2 == 0.0) - b2 = blas::xmyNorm(b, xp); - else - blas::xpay(b, -1.0, xp); + mmdag.Expose()->M(xe, x); + + if (param.compute_true_res) { + bool is_zero = true; + for (auto i = 0u; i < b2.size(); i++) { + is_zero = is_zero || b2[i] == 0.0; + if (b2[i] == 0.0 && !is_zero) errorQuda("Mixture of zero and non-zero sources not supported"); + } + if (is_zero) b2 = blas::xmyNorm(b, xe); + } else { + blas::xpay(b, -1.0, xe); + } // compute solution to residual equation - CG::operator()(yp, xp); + CG::operator()(ye, xe); - mmdag.Expose()->Mdag(xp, yp); + mmdag.Expose()->Mdag(xe, ye); // compute full solution - blas::xpy(xp, x); + blas::xpy(xe, x); } else { - CG::operator()(yp, b); - mmdag.Expose()->Mdag(x, yp); + CG::operator()(ye, b); + mmdag.Expose()->Mdag(x, ye); } if (param.compute_true_res || (param.use_init_guess && param.return_residual)) { // compute the true residual - mmdag.Expose()->M(xp, x); - blas::xpay(b, -1.0, xp); // xp now holds the residual + mmdag.Expose()->M(xe, x); + blas::xpay(b, -1.0, xe); // xe now holds the residual - double r2; + vector r2(b2.size()); if (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) { - double3 h3 = blas::HeavyQuarkResidualNorm(x, xp); - r2 = h3.y; - param.true_res_hq = sqrt(h3.z); + auto hq = blas::HeavyQuarkResidualNorm(x, xe); + for (auto i = 0u; i < b.size(); i++) { + param.true_res_hq[i] = sqrt(hq[i].z); + r2[i] = hq[i].y; + } } else { - r2 = blas::norm2(xp); + r2 = blas::norm2(xe); } - param.true_res = sqrt(r2 / b2); - PrintSummary("CA-CGNE", param.iter - iter0, r2, b2, stopping(param.tol, b2, param.residual_type), param.tol_hq); + for (auto i = 0u; i < b.size(); i++) param.true_res[i] = sqrt(r2[i] / b2[i]); + PrintSummary("CGNE", param.iter - iter0, r2, b2, stopping(param.tol, b2, param.residual_type), param.tol_hq); } } @@ -152,13 +161,13 @@ 
namespace quda { { } - void CGNR::create(ColorSpinorField &x, const ColorSpinorField &b) + void CGNR::create(cvector_ref &x, cvector_ref &b) { Solver::create(x, b); if (!init) { - ColorSpinorParam csParam(b); + ColorSpinorParam csParam(b[0]); csParam.create = QUDA_ZERO_FIELD_CREATE; - br = ColorSpinorField(csParam); + resize(br, b.size(), csParam); init = true; } } @@ -171,7 +180,7 @@ namespace quda { } // CGNR: Mdag M x = Mdag b is solved. - void CGNR::operator()(ColorSpinorField &x, const ColorSpinorField &b) + void CGNR::operator()(cvector_ref &x, cvector_ref &b) { if (param.maxiter == 0 || param.Nsteps == 0) { if (param.use_init_guess == QUDA_USE_INIT_GUESS_NO) blas::zero(x); @@ -181,10 +190,15 @@ namespace quda { create(x, b); const int iter0 = param.iter; - double b2 = 0.0; + vector b2(b.size(), 0.0); if (param.compute_true_res) { b2 = blas::norm2(b); - if (b2 == 0.0) { // compute initial residual vector + bool is_zero = true; + for (auto i = 0u; i < b2.size(); i++) { + is_zero = is_zero && b2[i] == 0.0; + if (b2[i] == 0.0 && !is_zero) errorQuda("Mixture of zero and non-zero sources not supported"); + } + if (is_zero) { // compute initial residual vector mdagm.Expose()->M(br, x); b2 = blas::norm2(br); } @@ -199,15 +213,17 @@ namespace quda { blas::xpay(b, -1.0, br); // br now holds the residual if (param.compute_true_res) { - double r2; + vector r2(b.size()); if (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) { - double3 h3 = blas::HeavyQuarkResidualNorm(x, br); - r2 = h3.y; - param.true_res_hq = sqrt(h3.z); + auto hq = blas::HeavyQuarkResidualNorm(x, br); + for (auto i = 0u; i < b.size(); i++) { + param.true_res_hq[i] = sqrt(hq[i].z); + r2[i] = hq[i].y; + } } else { r2 = blas::norm2(br); } - param.true_res = sqrt(r2 / b2); + for (auto i = 0u; i < b.size(); i++) param.true_res[i] = sqrt(r2[i] / b2[i]); PrintSummary("CGNR", param.iter - iter0, r2, b2, stopping(param.tol, b2, param.residual_type), param.tol_hq); } } From 
d5c6708ba10e5ef0fc73e0bb9ff9ccd9fb3cef94 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Tue, 10 Sep 2024 14:18:24 -0700 Subject: [PATCH 068/103] CG3 is now MRHS --- include/invert_quda.h | 25 ++++----- lib/inv_cg3_quda.cpp | 120 +++++++++++++++++++++++------------------- lib/inv_cg_quda.cpp | 4 +- 3 files changed, 79 insertions(+), 70 deletions(-) diff --git a/include/invert_quda.h b/include/invert_quda.h index 9302d968ac..d64e1cc2a9 100644 --- a/include/invert_quda.h +++ b/include/invert_quda.h @@ -877,14 +877,14 @@ namespace quda { { private: - ColorSpinorField y; - ColorSpinorField r; - ColorSpinorField tmp; - ColorSpinorField ArS; - ColorSpinorField rS; - ColorSpinorField xS; - ColorSpinorField xS_old; - ColorSpinorField rS_old; + std::vector y; + std::vector r; + std::vector tmp; + std::vector ArS; + std::vector rS; + std::vector xS; + std::vector xS_old; + std::vector rS_old; bool init = false; /** @@ -892,17 +892,12 @@ namespace quda { @param[in] x Solution vector @param[in] b Source vector */ - void create(ColorSpinorField &x, const ColorSpinorField &b); + void create(cvector_ref &x, cvector_ref &b); public: CG3(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, SolverParam ¶m); - virtual void operator()(cvector_ref &out, cvector_ref &in) override - { - for (auto i = 0u; i < in.size(); i++) operator()(out[i], in[i]); - } - - void operator()(ColorSpinorField &out, const ColorSpinorField &in); + virtual void operator()(cvector_ref &out, cvector_ref &in) override; /** @return Return the residual vector from the prior solve diff --git a/lib/inv_cg3_quda.cpp b/lib/inv_cg3_quda.cpp index eb2564ba7d..61d9120648 100644 --- a/lib/inv_cg3_quda.cpp +++ b/lib/inv_cg3_quda.cpp @@ -165,25 +165,30 @@ namespace quda { } } - void CG3::create(ColorSpinorField &x, const ColorSpinorField &b) + void CG3::create(cvector_ref &x, cvector_ref &b) { Solver::create(x, b); - if (!init) { - ColorSpinorParam csParam(b); + if (!init || r.size() != 
b.size()) { + ColorSpinorParam csParam(b[0]); csParam.create = QUDA_ZERO_FIELD_CREATE; - r = ColorSpinorField(csParam); - y = ColorSpinorField(csParam); + resize(r, b.size(), csParam); + resize(y, b.size(), csParam); // Sloppy fields - const bool mixed_precision = (param.precision != param.precision_sloppy); csParam.setPrecision(param.precision_sloppy); - ArS = ColorSpinorField(csParam); - rS_old = ColorSpinorField(csParam); - rS = mixed_precision ? ColorSpinorField(csParam) : r.create_alias(); - xS = mixed_precision ? ColorSpinorField(csParam) : x.create_alias(); - xS_old = mixed_precision ? ColorSpinorField(csParam) : y.create_alias(); - tmp = ColorSpinorField(csParam); + resize(ArS, b.size(), csParam); + resize(rS_old, b.size(), csParam); + if (param.precision != param.precision_sloppy) { + resize(rS, b.size(), csParam); + resize(xS, b.size(), csParam); + resize(xS_old, b.size(), csParam); + } else { + create_alias(rS, r); + create_alias(xS, x); + create_alias(xS_old, y); + } + resize(tmp, b.size(), csParam); init = true; } @@ -198,26 +203,21 @@ namespace quda { return r; } - void CG3::operator()(ColorSpinorField &x, const ColorSpinorField &b) + void CG3::operator()(cvector_ref &x, cvector_ref &b) { getProfile().TPSTART(QUDA_PROFILE_PREAMBLE); // Check to see that we're not trying to invert on a zero-field source - double b2 = blas::norm2(b); - if (b2 == 0 - && (param.compute_null_vector == QUDA_COMPUTE_NULL_VECTOR_NO || param.use_init_guess == QUDA_USE_INIT_GUESS_NO)) { + auto b2 = blas::norm2(b); + if (is_zero_src(x, b, b2)) { getProfile().TPSTOP(QUDA_PROFILE_PREAMBLE); - printfQuda("Warning: inverting on zero-field source\n"); - x = b; - param.true_res = 0.0; - param.true_res_hq = 0.0; return; } const bool mixed_precision = (param.precision != param.precision_sloppy); create(x, b); - double stop = stopping(param.tol, b2, param.residual_type); // stopping condition of solver + auto stop = stopping(param.tol, b2, param.residual_type); // stopping condition of 
solver const bool use_heavy_quark_res = (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) ? true : false; @@ -233,17 +233,18 @@ namespace quda { // these are only used if we use the heavy_quark_res const int hqmaxresIncrease = maxResIncrease + 1; int heavy_quark_check = param.heavy_quark_check; // how often to check the heavy quark residual - double heavy_quark_res = 0.0; // heavy quark residual - double heavy_quark_res_old = 0.0; // heavy quark residual + vector heavy_quark_res(b.size(), 0.0); // heavy quark residual + vector heavy_quark_res_old(b.size(), 0.0); // heavy quark residual int hqresIncrease = 0; bool L2breakdown = false; // compute initial residual depending on whether we have an initial guess or not - double r2; + vector r2; if (param.use_init_guess == QUDA_USE_INIT_GUESS_YES) { mat(r, x); r2 = blas::xmyNorm(b, r); - if (b2 == 0) b2 = r2; + for (auto i = 0u; i < b.size(); i++) + if (b2[i] == 0) b2[i] = r2[i]; if (mixed_precision) { blas::copy(y, x); blas::zero(xS); @@ -260,7 +261,8 @@ namespace quda { blas::copy(rS, r); if (use_heavy_quark_res) { - heavy_quark_res = sqrt(blas::HeavyQuarkResidualNorm(x, r).z); + auto hq = blas::HeavyQuarkResidualNorm(x, r); + for (auto i = 0u; i < hq.size(); i++) heavy_quark_res[i] = sqrt(hq[i].z); heavy_quark_res_old = heavy_quark_res; } @@ -268,8 +270,8 @@ namespace quda { if (convergence(r2, heavy_quark_res, stop, param.tol_hq)) return; getProfile().TPSTART(QUDA_PROFILE_COMPUTE); - double r2_old = r2; - double rNorm = sqrt(r2); + auto r2_old = r2; + double rNorm = sqrt(r2[0]); double r0Norm = rNorm; double maxrx = rNorm; double maxrr = rNorm; @@ -278,21 +280,23 @@ namespace quda { int k = 0; PrintStats("CG3", k, r2, b2, heavy_quark_res); - double rho = 1.0, gamma = 1.0; + vector rho(b.size(), 1.0); + vector gamma(b.size(), 1.0); while ( !convergence(r2, heavy_quark_res, stop, param.tol_hq) && k < param.maxiter) { matSloppy(ArS, rS); - double gamma_old = gamma; - double rAr = blas::reDotProduct(rS,ArS); - gamma = r2 
/ rAr; + auto gamma_old = gamma; + auto rAr = blas::reDotProduct(rS, ArS); + for (auto i = 0u; i < b.size(); i++) gamma[i] = r2[i] / rAr[i]; // CG3 step if (k == 0 || restart) { // First iteration r2 = blas::quadrupleCG3InitNorm(gamma, xS, rS, xS_old, rS_old, ArS); restart = false; } else { - rho = rho/(rho-(gamma/gamma_old)*(r2/r2_old)); + for (auto i = 0u; i < rho.size(); i++) + rho[i] = rho[i] / (rho[i] - (gamma[i] / gamma_old[i]) * (r2[i] / r2_old[i])); r2_old = r2; r2 = blas::quadrupleCG3UpdateNorm(gamma, rho, xS, rS, xS_old, rS_old, ArS); } @@ -302,15 +306,17 @@ namespace quda { if (use_heavy_quark_res && k % heavy_quark_check == 0) { heavy_quark_res_old = heavy_quark_res; if (mixed_precision) { - heavy_quark_res = sqrt(blas::xpyHeavyQuarkResidualNorm(xS, y, rS).z); + auto hq = blas::xpyHeavyQuarkResidualNorm(xS, y, rS); + for (auto i = 0u; i < b2.size(); i++) heavy_quark_res[i] = sqrt(hq[i].z); } else { - heavy_quark_res = sqrt(blas::HeavyQuarkResidualNorm(xS, rS).z); + auto hq = blas::HeavyQuarkResidualNorm(xS, rS); + for (auto i = 0u; i < b2.size(); i++) heavy_quark_res[i] = sqrt(hq[i].z); } } // reliable update conditions if (mixed_precision) { - rNorm = sqrt(r2); + rNorm = sqrt(r2[0]); if (rNorm > maxrx) maxrx = rNorm; if (rNorm > maxrr) maxrr = rNorm; bool update = (rNorm < delta*r0Norm && r0Norm <= maxrx); // condition for x @@ -330,13 +336,14 @@ namespace quda { blas::xpy(x, y); mat(r, y); r2 = blas::xmyNorm(b, r); - param.true_res = sqrt(r2 / b2); + for (auto i = 0u; i < b2.size(); i++) param.true_res[i] = sqrt(r2[i] / b2[i]); if (use_heavy_quark_res) { - heavy_quark_res = sqrt(blas::HeavyQuarkResidualNorm(y, r).z); + auto hq = blas::HeavyQuarkResidualNorm(y, r); + for (auto i = 0u; i < b2.size(); i++) heavy_quark_res = sqrt(hq[i].z); param.true_res_hq = heavy_quark_res; } - rNorm = sqrt(r2); - r0Norm = sqrt(r2); + rNorm = sqrt(r2[0]); + r0Norm = sqrt(r2[0]); maxrr = rNorm; maxrx = rNorm; // we update sloppy and old fields @@ -344,19 +351,20 @@ 
namespace quda { blas::copy(rS, r); blas::axpy(-1., xS, xS_old); // we preserve the orthogonality between the previous residual and the new - Complex rr_old = blas::cDotProduct(rS, rS_old); - r2_old = blas::caxpyNorm(-rr_old/r2, rS, rS_old); + auto rr_old = blas::cDotProduct(rS, rS_old); + for (auto i = 0u; i < r2.size(); i++) rr_old[i] /= r2[i]; + r2_old = blas::caxpyNorm(-rr_old, rS, rS_old); blas::zero(xS); } } // break-out check if we have reached the limit of the precision - if (sqrt(r2) > r0Norm) { + if (sqrt(r2[0]) > r0Norm) { resIncrease++; resIncreaseTotal++; warningQuda( "CG3: new reliable residual norm %e is greater than previous reliable residual norm %e (total #inc %i)", - sqrt(r2), r0Norm, resIncreaseTotal); + sqrt(r2[0]), r0Norm, resIncreaseTotal); if (resIncrease > maxResIncrease or resIncreaseTotal > maxResIncreaseTotal) { if (use_heavy_quark_res) { L2breakdown = true; @@ -376,9 +384,10 @@ namespace quda { warningQuda("CG3: Restarting without reliable updates for heavy-quark residual"); restart = true; L2breakdown = false; - if (heavy_quark_res > heavy_quark_res_old) { + if (heavy_quark_res[0] > heavy_quark_res_old[0]) { hqresIncrease++; - warningQuda("CG3: new reliable HQ residual norm %e is greater than previous reliable residual norm %e", heavy_quark_res, heavy_quark_res_old); + warningQuda("CG3: new reliable HQ residual norm %e is greater than previous reliable residual norm %e", + heavy_quark_res[0], heavy_quark_res_old[0]); // break out if we do not improve here anymore if (hqresIncrease > hqmaxresIncrease) { warningQuda("CG3: solver exiting due to too many heavy quark residual norm increases"); @@ -390,22 +399,23 @@ namespace quda { if (convergence(r2, heavy_quark_res, stop, param.tol_hq)) { mat(r, x); r2 = blas::xmyNorm(b, r); - r0Norm = sqrt(r2); + r0Norm = sqrt(r2[0]); // we update sloppy and old fields if (!convergence(r2, heavy_quark_res, stop, param.tol_hq)) { // we preserve the orthogonality between the previous residual and the new - 
Complex rr_old = blas::cDotProduct(rS, rS_old); - r2_old = blas::caxpyNorm(-rr_old/r2, rS, rS_old); + auto rr_old = blas::cDotProduct(rS, rS_old); + for (auto i = 0u; i < r2.size(); i++) rr_old[i] /= r2[i]; + r2_old = blas::caxpyNorm(-rr_old, rS, rS_old); } } // break-out check if we have reached the limit of the precision - if (sqrt(r2) > r0Norm) { + if (sqrt(r2[0]) > r0Norm) { resIncrease++; resIncreaseTotal++; warningQuda( "CG3: new reliable residual norm %e is greater than previous reliable residual norm %e (total #inc %i)", - sqrt(r2), r0Norm, resIncreaseTotal); + sqrt(r2[0]), r0Norm, resIncreaseTotal); if (resIncrease > maxResIncrease or resIncreaseTotal > maxResIncreaseTotal) { warningQuda("CG3: solver exiting due to too many true residual norm increases"); break; @@ -427,8 +437,12 @@ namespace quda { // compute the true residuals if (!mixed_precision && param.compute_true_res) { mat(r, x); - param.true_res = sqrt(blas::xmyNorm(b, r) / b2); - if (use_heavy_quark_res) param.true_res_hq = sqrt(blas::HeavyQuarkResidualNorm(x, r).z); + r2 = blas::xmyNorm(b, r); + for (auto i = 0u; i < b.size(); i++) param.true_res[i] = sqrt(r2[i] / b2[i]); + if (use_heavy_quark_res) { + auto hq = blas::HeavyQuarkResidualNorm(x, r); + for (auto i = 0u; i < b.size(); i++) param.true_res_hq[i] = sqrt(hq[i].z); + } } PrintSummary("CG3", k, r2, b2, stop, param.tol_hq); diff --git a/lib/inv_cg_quda.cpp b/lib/inv_cg_quda.cpp index b6c7c6a593..97a3481085 100644 --- a/lib/inv_cg_quda.cpp +++ b/lib/inv_cg_quda.cpp @@ -73,7 +73,7 @@ namespace quda { void CGNE::create(cvector_ref &x, cvector_ref &b) { Solver::create(x, b); - if (!init) { + if (!init || xe.size() != b.size()) { ColorSpinorParam csParam(x[0]); csParam.create = QUDA_NULL_FIELD_CREATE; resize(xe, b.size(), csParam); @@ -164,7 +164,7 @@ namespace quda { void CGNR::create(cvector_ref &x, cvector_ref &b) { Solver::create(x, b); - if (!init) { + if (!init || br.size() != b.size()) { ColorSpinorParam csParam(b[0]); csParam.create = 
QUDA_ZERO_FIELD_CREATE; resize(br, b.size(), csParam); From 2e52dad62066ebe6194f67303391885d40641252 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Wed, 11 Sep 2024 12:36:42 -0700 Subject: [PATCH 069/103] Remove derived CGNR and CGNE specializations for CG/CA-CG/CG3: we now have single implementations of these, which use a solver factory to call the correct CG method --- include/invert_quda.h | 161 +--------------------------------------- lib/CMakeLists.txt | 1 + lib/inv_ca_cg.cpp | 153 -------------------------------------- lib/inv_cg3_quda.cpp | 150 ------------------------------------- lib/inv_cg_quda.cpp | 169 ------------------------------------------ lib/inv_cgne.cpp | 103 +++++++++++++++++++++++++ lib/inv_cgnr.cpp | 90 ++++++++++++++++++++++ lib/solver.cpp | 8 +- 8 files changed, 202 insertions(+), 633 deletions(-) create mode 100644 lib/inv_cgne.cpp create mode 100644 lib/inv_cgnr.cpp diff --git a/include/invert_quda.h b/include/invert_quda.h index d64e1cc2a9..640c90c5f7 100644 --- a/include/invert_quda.h +++ b/include/invert_quda.h @@ -804,14 +804,14 @@ namespace quda { void hqsolve(cvector_ref &out, cvector_ref &in); }; - class CGNE : public CG + class CGNE : public Solver { - private: DiracMMdag mmdag; DiracMMdag mmdagSloppy; DiracMMdag mmdagPrecon; DiracMMdag mmdagEig; + std::unique_ptr cg; std::vector xe; std::vector ye; bool init = false; @@ -839,14 +839,14 @@ namespace quda { virtual QudaInverterType getInverterType() const final { return QUDA_CGNE_INVERTER; } }; - class CGNR : public CG + class CGNR : public Solver { - private: DiracMdagM mdagm; DiracMdagM mdagmSloppy; DiracMdagM mdagmPrecon; DiracMdagM mdagmEig; + std::unique_ptr cg; std::vector br; bool init = false; @@ -909,81 +909,6 @@ namespace quda { virtual QudaInverterType getInverterType() const override { return QUDA_CG3_INVERTER; } }; - class CG3NE : public CG3 - { - - private: - DiracMMdag mmdag; - DiracMMdag mmdagSloppy; - DiracMMdag mmdagPrecon; - ColorSpinorField xp; -
ColorSpinorField yp; - bool init = false; - - /** - @brief Initiate the fields needed by the solver - @param[in] x Solution vector - @param[in] b Source vector - */ - void create(ColorSpinorField &x, const ColorSpinorField &b); - - public: - CG3NE(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, SolverParam ¶m); - - void operator()(cvector_ref &out, cvector_ref &in) override - { - for (auto i = 0u; i < in.size(); i++) operator()(out[i], in[i]); - } - - void operator()(ColorSpinorField &out, const ColorSpinorField &in); - - /** - @return Return the residual vector from the prior solve - */ - cvector_ref get_residual() override; - - virtual bool hermitian() const final { return false; } /** CG3NE is for any system */ - - virtual QudaInverterType getInverterType() const final { return QUDA_CG3NE_INVERTER; } - }; - - class CG3NR : public CG3 - { - - private: - DiracMdagM mdagm; - DiracMdagM mdagmSloppy; - DiracMdagM mdagmPrecon; - ColorSpinorField br; - bool init = false; - - /** - @brief Initiate the fields needed by the solver - @param[in] x Solution vector - @param[in] b Source vector - */ - void create(ColorSpinorField &x, const ColorSpinorField &b); - - public: - CG3NR(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, SolverParam ¶m); - - void operator()(cvector_ref &out, cvector_ref &in) override - { - for (auto i = 0u; i < in.size(); i++) operator()(out[i], in[i]); - } - - void operator()(ColorSpinorField &out, const ColorSpinorField &in); - - /** - @return Return the residual vector from the prior solve - */ - cvector_ref get_residual() override; - - virtual bool hermitian() const final { return false; } /** CG3NR is for any system */ - - virtual QudaInverterType getInverterType() const final { return QUDA_CG3NR_INVERTER; } - }; - class PreconCG : public Solver { private: std::shared_ptr K; @@ -1383,84 +1308,6 @@ namespace quda { virtual QudaInverterType getInverterType() const override { 
return QUDA_CA_CG_INVERTER; } }; - class CACGNE : public CACG { - - private: - DiracMMdag mmdag; - DiracMMdag mmdagSloppy; - DiracMMdag mmdagPrecon; - DiracMMdag mmdagEig; - ColorSpinorField xp; - ColorSpinorField yp; - bool init = false; - - /** - @brief Initiate the fields needed by the solver - @param[in] x Solution vector - @param[in] b Source vector - */ - void create(ColorSpinorField &x, const ColorSpinorField &b); - - public: - CACGNE(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, - const DiracMatrix &matEig, SolverParam ¶m); - - void operator()(cvector_ref &out, cvector_ref &in) override - { - for (auto i = 0u; i < in.size(); i++) operator()(out[i], in[i]); - } - - void operator()(ColorSpinorField &out, const ColorSpinorField &in); - - /** - @return Return the residual vector from the prior solve - */ - cvector_ref get_residual() override; - - virtual bool hermitian() const final { return false; } /** CA-CGNE is for any linear system */ - - virtual QudaInverterType getInverterType() const final { return QUDA_CA_CGNE_INVERTER; } - }; - - class CACGNR : public CACG - { - - private: - DiracMdagM mdagm; - DiracMdagM mdagmSloppy; - DiracMdagM mdagmPrecon; - DiracMdagM mdagmEig; - ColorSpinorField br; - bool init = false; - - /** - @brief Initiate the fields needed by the solver - @param[in] x Solution vector - @param[in] b Source vector - */ - void create(ColorSpinorField &x, const ColorSpinorField &b); - - public: - CACGNR(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, - const DiracMatrix &matEig, SolverParam ¶m); - - void operator()(cvector_ref &out, cvector_ref &in) override - { - for (auto i = 0u; i < in.size(); i++) operator()(out[i], in[i]); - } - - void operator()(ColorSpinorField &out, const ColorSpinorField &in); - - /** - @return Return the residual vector from the prior solve - */ - cvector_ref get_residual() override; - - virtual bool hermitian() const final { return false; } 
/** CA-CGNR is for any linear system */ - - virtual QudaInverterType getInverterType() const final { return QUDA_CA_CGNR_INVERTER; } - }; - /** @brief Communication-avoiding GCR solver. This solver does un-preconditioned GCR, first building up a polynomial in the diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index 0b08bb37f4..950ac83a09 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -29,6 +29,7 @@ set (QUDA_OBJS inv_multi_cg_quda.cpp inv_eigcg_quda.cpp gauge_ape.cu gauge_stout.cu gauge_hyp.cu gauge_wilson_flow.cu gauge_plaq.cu gauge_laplace.cpp gauge_observable.cpp + inv_cgnr.cpp inv_cgne.cpp inv_cg3_quda.cpp inv_ca_gcr.cpp inv_ca_cg.cpp inv_gcr_quda.cpp inv_mr_quda.cpp inv_sd_quda.cpp inv_pcg_quda.cpp inv_mre.cpp interface_quda.cpp util_quda.cpp diff --git a/lib/inv_ca_cg.cpp b/lib/inv_ca_cg.cpp index 66f3800afb..c50f18dc31 100644 --- a/lib/inv_ca_cg.cpp +++ b/lib/inv_ca_cg.cpp @@ -27,159 +27,6 @@ namespace quda if (!param.is_preconditioner) getProfile().TPSTOP(QUDA_PROFILE_FREE); } - CACGNE::CACGNE(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, - const DiracMatrix &matEig, SolverParam ¶m) : - CACG(mmdag, mmdagSloppy, mmdagPrecon, mmdagEig, param), - mmdag(mat.Expose()), - mmdagSloppy(matSloppy.Expose()), - mmdagPrecon(matPrecon.Expose()), - mmdagEig(matEig.Expose()) - { - } - - void CACGNE::create(ColorSpinorField &x, const ColorSpinorField &b) - { - Solver::create(x, b); - if (!init) { - ColorSpinorParam csParam(x); - csParam.create = QUDA_NULL_FIELD_CREATE; - xp = ColorSpinorField(csParam); - csParam.create = QUDA_ZERO_FIELD_CREATE; - yp = ColorSpinorField(csParam); - init = true; - } - } - - cvector_ref CACGNE::get_residual() - { - if (!init) errorQuda("No residual vector present"); - if (!param.return_residual) errorQuda("SolverParam::return_residual not enabled"); - // CG residual will match the CGNE residual (FIXME: but only with zero initial guess?) - return param.use_init_guess ? 
xp : CACG::get_residual(); - } - - // CACGNE: M Mdag y = b is solved; x = Mdag y is returned as solution. - void CACGNE::operator()(ColorSpinorField &x, const ColorSpinorField &b) - { - if (param.maxiter == 0 || param.Nsteps == 0) { - if (param.use_init_guess == QUDA_USE_INIT_GUESS_NO) blas::zero(x); - return; - } - - create(x, b); - - const int iter0 = param.iter; - double b2 = param.compute_true_res ? blas::norm2(b) : 0.0; - - if (param.use_init_guess == QUDA_USE_INIT_GUESS_YES) { - // compute initial residual - mmdag.Expose()->M(xp, x); - if (param.compute_true_res && b2 == 0.0) - b2 = blas::xmyNorm(b, xp); - else - blas::xpay(b, -1.0, xp); - - // compute solution to residual equation - CACG::operator()(yp, xp); - - mmdag.Expose()->Mdag(xp, yp); - - // compute full solution - blas::xpy(xp, x); - } else { - CACG::operator()(yp, b); - mmdag.Expose()->Mdag(x, yp); - } - - if (param.compute_true_res || (param.use_init_guess && param.return_residual)) { - // compute the true residual - mmdag.Expose()->M(xp, x); - blas::xpay(b, -1.0, xp); // xp now holds the residual - - double r2; - if (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) { - double3 h3 = blas::HeavyQuarkResidualNorm(x, xp); - r2 = h3.y; - param.true_res_hq = sqrt(h3.z); - } else { - r2 = blas::norm2(xp); - } - param.true_res = sqrt(r2 / b2); - PrintSummary("CA-CGNE", param.iter - iter0, r2, b2, stopping(param.tol, b2, param.residual_type), param.tol_hq); - } - } - - CACGNR::CACGNR(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, - const DiracMatrix &matEig, SolverParam ¶m) : - CACG(mdagm, mdagmSloppy, mdagmPrecon, mdagmEig, param), - mdagm(mat.Expose()), - mdagmSloppy(matSloppy.Expose()), - mdagmPrecon(matPrecon.Expose()), - mdagmEig(matEig.Expose()) - { - } - - void CACGNR::create(ColorSpinorField &x, const ColorSpinorField &b) - { - Solver::create(x, b); - if (!init) { - ColorSpinorParam csParam(b); - csParam.create = QUDA_ZERO_FIELD_CREATE; - br = 
ColorSpinorField(csParam); - init = true; - } - } - - cvector_ref CACGNR::get_residual() - { - if (!init) errorQuda("No residual vector present"); - if (!param.return_residual) errorQuda("SolverParam::return_residual not enabled"); - return br; - } - - // CACGNR: Mdag M x = Mdag b is solved. - void CACGNR::operator()(ColorSpinorField &x, const ColorSpinorField &b) - { - if (param.maxiter == 0 || param.Nsteps == 0) { - if (param.use_init_guess == QUDA_USE_INIT_GUESS_NO) blas::zero(x); - return; - } - - create(x, b); - - const int iter0 = param.iter; - double b2 = 0.0; - if (param.compute_true_res) { - b2 = blas::norm2(b); - if (b2 == 0.0) { // compute initial residual vector - mdagm.Expose()->M(br, x); - b2 = blas::norm2(br); - } - } - - mdagm.Expose()->Mdag(br, b); - CACG::operator()(x, br); - - if (param.compute_true_res || param.return_residual) { - // compute the true residual - mdagm.Expose()->M(br, x); - blas::xpay(b, -1.0, br); // br now holds the residual - - if (param.compute_true_res) { - double r2; - if (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) { - double3 h3 = blas::HeavyQuarkResidualNorm(x, br); - r2 = h3.y; - param.true_res_hq = sqrt(h3.z); - } else { - r2 = blas::norm2(br); - } - param.true_res = sqrt(r2 / b2); - PrintSummary("CA-CGNR", param.iter - iter0, r2, b2, stopping(param.tol, b2, param.residual_type), param.tol_hq); - } - } - } - void CACG::create(ColorSpinorField &x, const ColorSpinorField &b) { Solver::create(x, b); diff --git a/lib/inv_cg3_quda.cpp b/lib/inv_cg3_quda.cpp index 61d9120648..e19fd508bc 100644 --- a/lib/inv_cg3_quda.cpp +++ b/lib/inv_cg3_quda.cpp @@ -15,156 +15,6 @@ namespace quda { { } - CG3NE::CG3NE(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, SolverParam ¶m) : - CG3(mmdag, mmdagSloppy, mmdagPrecon, param), - mmdag(mat.Expose()), - mmdagSloppy(matSloppy.Expose()), - mmdagPrecon(matPrecon.Expose()) - { - } - - void CG3NE::create(ColorSpinorField &x, const ColorSpinorField &b) - { 
- Solver::create(x, b); - if (!init) { - ColorSpinorParam csParam(x); - csParam.create = QUDA_NULL_FIELD_CREATE; - xp = ColorSpinorField(csParam); - csParam.create = QUDA_ZERO_FIELD_CREATE; - yp = ColorSpinorField(csParam); - init = true; - } - } - - cvector_ref CG3NE::get_residual() - { - if (!init) errorQuda("No residual vector present"); - if (!param.return_residual) errorQuda("SolverParam::return_residual not enabled"); - // CG3 residual will match the CG3NE residual (FIXME: but only with zero initial guess?) - return param.use_init_guess ? xp : CG3::get_residual(); - } - - // CG3NE: M Mdag y = b is solved; x = Mdag y is returned as solution. - void CG3NE::operator()(ColorSpinorField &x, const ColorSpinorField &b) - { - if (param.maxiter == 0 || param.Nsteps == 0) { - if (param.use_init_guess == QUDA_USE_INIT_GUESS_NO) blas::zero(x); - return; - } - - create(x, b); - - const int iter0 = param.iter; - double b2 = param.compute_true_res ? blas::norm2(b) : 0.0; - - if (param.use_init_guess == QUDA_USE_INIT_GUESS_YES) { - // compute initial residual - mmdag.Expose()->M(xp, x); - if (param.compute_true_res && b2 == 0.0) - b2 = blas::xmyNorm(b, xp); - else - blas::xpay(b, -1.0, xp); - - // compute solution to residual equation - CG3::operator()(yp, xp); - - mmdag.Expose()->Mdag(xp, yp); - - // compute full solution - blas::xpy(xp, x); - } else { - CG3::operator()(yp, b); - mmdag.Expose()->Mdag(x, yp); - } - - if (param.compute_true_res || (param.use_init_guess && param.return_residual)) { - // compute the true residual - mmdag.Expose()->M(xp, x); - blas::xpay(b, -1.0, xp); // xp now holds the residual - - double r2; - if (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) { - double3 h3 = blas::HeavyQuarkResidualNorm(x, xp); - r2 = h3.y; - param.true_res_hq = sqrt(h3.z); - } else { - r2 = blas::norm2(xp); - } - param.true_res = sqrt(r2 / b2); - PrintSummary("CG3NE", param.iter - iter0, r2, b2, stopping(param.tol, b2, param.residual_type), param.tol_hq); - } - } - - 
CG3NR::CG3NR(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, SolverParam &param) : - CG3(mdagm, mdagmSloppy, mdagmPrecon, param), - mdagm(mat.Expose()), - mdagmSloppy(matSloppy.Expose()), - mdagmPrecon(matPrecon.Expose()) - { - } - - void CG3NR::create(ColorSpinorField &x, const ColorSpinorField &b) - { - Solver::create(x, b); - - if (!init) { - ColorSpinorParam csParam(b); - csParam.create = QUDA_ZERO_FIELD_CREATE; - br = ColorSpinorField(csParam); - init = true; - } - } - - cvector_ref CG3NR::get_residual() - { - if (!init) errorQuda("No residual vector present"); - if (!param.return_residual) errorQuda("SolverParam::return_residual not enabled"); - return br; - } - - // CG3NR: Mdag M x = Mdag b is solved. - void CG3NR::operator()(ColorSpinorField &x, const ColorSpinorField &b) - { - if (param.maxiter == 0 || param.Nsteps == 0) { - if (param.use_init_guess == QUDA_USE_INIT_GUESS_NO) blas::zero(x); - return; - } - - create(x, b); - - const int iter0 = param.iter; - double b2 = 0.0; - if (param.compute_true_res) { - b2 = blas::norm2(b); - if (b2 == 0.0) { // compute initial residual vector - mdagm.Expose()->M(br, x); - b2 = blas::norm2(br); - } - } - - mdagm.Expose()->Mdag(br, b); - CG3::operator()(x, br); - - if (param.compute_true_res || param.return_residual) { - // compute the true residual - mdagm.Expose()->M(br, x); - blas::xpay(b, -1.0, br); // br now holds the residual - - if (param.compute_true_res) { - double r2; - if (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) { - double3 h3 = blas::HeavyQuarkResidualNorm(x, br); - r2 = h3.y; - param.true_res_hq = sqrt(h3.z); - } else { - r2 = blas::norm2(br); - } - param.true_res = sqrt(r2 / b2); - PrintSummary("CG3NR", param.iter - iter0, r2, b2, stopping(param.tol, b2, param.residual_type), param.tol_hq); - } - } - } - void CG3::create(cvector_ref &x, cvector_ref &b) { Solver::create(x, b); diff --git a/lib/inv_cg_quda.cpp b/lib/inv_cg_quda.cpp index 97a3481085..5531639df1 100644 
--- a/lib/inv_cg_quda.cpp +++ b/lib/inv_cg_quda.cpp @@ -60,175 +60,6 @@ namespace quda { if (!param.use_sloppy_partial_accumulator) create_alias(x_sloppy, x); } - CGNE::CGNE(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, - const DiracMatrix &matEig, SolverParam &param) : - CG(mmdag, mmdagSloppy, mmdagPrecon, mmdagEig, param), - mmdag(mat.Expose()), - mmdagSloppy(matSloppy.Expose()), - mmdagPrecon(matPrecon.Expose()), - mmdagEig(matEig.Expose()) - { - } - - void CGNE::create(cvector_ref &x, cvector_ref &b) - { - Solver::create(x, b); - if (!init || xe.size() != b.size()) { - ColorSpinorParam csParam(x[0]); - csParam.create = QUDA_NULL_FIELD_CREATE; - resize(xe, b.size(), csParam); - csParam.create = QUDA_ZERO_FIELD_CREATE; - resize(ye, b.size(), csParam); - init = true; - } - } - - cvector_ref CGNE::get_residual() - { - if (!init) errorQuda("No residual vector present"); - if (!param.return_residual) errorQuda("SolverParam::return_residual not enabled"); - // CG residual will match the CGNE residual (FIXME: but only with zero initial guess?) - return param.use_init_guess ? xe : CG::get_residual(); - } - - // CGNE: M Mdag y = b is solved; x = Mdag y is returned as solution. - void CGNE::operator()(cvector_ref &x, cvector_ref &b) - { - if (param.maxiter == 0 || param.Nsteps == 0) { - if (param.use_init_guess == QUDA_USE_INIT_GUESS_NO) blas::zero(x); - return; - } - - create(x, b); - - const int iter0 = param.iter; - auto b2 = param.compute_true_res ? 
blas::norm2(b) : vector(b.size(), 0.0); - - if (param.use_init_guess == QUDA_USE_INIT_GUESS_YES) { - // compute initial residual - mmdag.Expose()->M(xe, x); - - if (param.compute_true_res) { - bool is_zero = true; - for (auto i = 0u; i < b2.size(); i++) { - is_zero = is_zero || b2[i] == 0.0; - if (b2[i] == 0.0 && !is_zero) errorQuda("Mixture of zero and non-zero sources not supported"); - } - if (is_zero) b2 = blas::xmyNorm(b, xe); - } else { - blas::xpay(b, -1.0, xe); - } - - // compute solution to residual equation - CG::operator()(ye, xe); - - mmdag.Expose()->Mdag(xe, ye); - - // compute full solution - blas::xpy(xe, x); - } else { - CG::operator()(ye, b); - mmdag.Expose()->Mdag(x, ye); - } - - if (param.compute_true_res || (param.use_init_guess && param.return_residual)) { - // compute the true residual - mmdag.Expose()->M(xe, x); - blas::xpay(b, -1.0, xe); // xe now holds the residual - - vector r2(b2.size()); - if (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) { - auto hq = blas::HeavyQuarkResidualNorm(x, xe); - for (auto i = 0u; i < b.size(); i++) { - param.true_res_hq[i] = sqrt(hq[i].z); - r2[i] = hq[i].y; - } - } else { - r2 = blas::norm2(xe); - } - for (auto i = 0u; i < b.size(); i++) param.true_res[i] = sqrt(r2[i] / b2[i]); - PrintSummary("CGNE", param.iter - iter0, r2, b2, stopping(param.tol, b2, param.residual_type), param.tol_hq); - } - } - - CGNR::CGNR(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, - const DiracMatrix &matEig, SolverParam &param) : - CG(mdagm, mdagmSloppy, mdagmPrecon, mdagmEig, param), - mdagm(mat.Expose()), - mdagmSloppy(matSloppy.Expose()), - mdagmPrecon(matPrecon.Expose()), - mdagmEig(matEig.Expose()) - { - } - - void CGNR::create(cvector_ref &x, cvector_ref &b) - { - Solver::create(x, b); - if (!init || br.size() != b.size()) { - ColorSpinorParam csParam(b[0]); - csParam.create = QUDA_ZERO_FIELD_CREATE; - resize(br, b.size(), csParam); - init = true; - } - } - - cvector_ref 
CGNR::get_residual() - { - if (!init) errorQuda("No residual vector present"); - if (!param.return_residual) errorQuda("SolverParam::return_residual not enabled"); - return br; - } - - // CGNR: Mdag M x = Mdag b is solved. - void CGNR::operator()(cvector_ref &x, cvector_ref &b) - { - if (param.maxiter == 0 || param.Nsteps == 0) { - if (param.use_init_guess == QUDA_USE_INIT_GUESS_NO) blas::zero(x); - return; - } - - create(x, b); - - const int iter0 = param.iter; - vector b2(b.size(), 0.0); - if (param.compute_true_res) { - b2 = blas::norm2(b); - bool is_zero = true; - for (auto i = 0u; i < b2.size(); i++) { - is_zero = is_zero && b2[i] == 0.0; - if (b2[i] == 0.0 && !is_zero) errorQuda("Mixture of zero and non-zero sources not supported"); - } - if (is_zero) { // compute initial residual vector - mdagm.Expose()->M(br, x); - b2 = blas::norm2(br); - } - } - - mdagm.Expose()->Mdag(br, b); - CG::operator()(x, br); - - if (param.compute_true_res || param.return_residual) { - // compute the true residual - mdagm.Expose()->M(br, x); - blas::xpay(b, -1.0, br); // br now holds the residual - - if (param.compute_true_res) { - vector r2(b.size()); - if (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) { - auto hq = blas::HeavyQuarkResidualNorm(x, br); - for (auto i = 0u; i < b.size(); i++) { - param.true_res_hq[i] = sqrt(hq[i].z); - r2[i] = hq[i].y; - } - } else { - r2 = blas::norm2(br); - } - for (auto i = 0u; i < b.size(); i++) param.true_res[i] = sqrt(r2[i] / b2[i]); - PrintSummary("CGNR", param.iter - iter0, r2, b2, stopping(param.tol, b2, param.residual_type), param.tol_hq); - } - } - } - void CG::operator()(cvector_ref &x, cvector_ref &b, cvector_ref &p_init, cvector &r2_old_init) { diff --git a/lib/inv_cgne.cpp b/lib/inv_cgne.cpp new file mode 100644 index 0000000000..9e7f19c97f --- /dev/null +++ b/lib/inv_cgne.cpp @@ -0,0 +1,103 @@ +#include "invert_quda.h" + +namespace quda +{ + + CGNE::CGNE(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix 
&matPrecon, + const DiracMatrix &matEig, SolverParam &param) : + Solver(mat, matSloppy, matPrecon, matEig, param), + mmdag(mat.Expose()), + mmdagSloppy(matSloppy.Expose()), + mmdagPrecon(matPrecon.Expose()), + mmdagEig(matEig.Expose()) + { + switch (param.inv_type) { + case QUDA_CGNE_INVERTER: cg = std::make_unique(mmdag, mmdagSloppy, mmdagPrecon, mmdagEig, param); break; + case QUDA_CA_CGNE_INVERTER: cg = std::make_unique(mmdag, mmdagSloppy, mmdagPrecon, mmdagEig, param); break; + case QUDA_CG3NE_INVERTER: cg = std::make_unique(mmdag, mmdagSloppy, mmdagPrecon, param); break; + default: errorQuda("Unexpected CG solver type %d", param.inv_type); + } + } + + void CGNE::create(cvector_ref &x, cvector_ref &b) + { + Solver::create(x, b); + if (!init || xe.size() != b.size()) { + ColorSpinorParam csParam(x[0]); + csParam.create = QUDA_NULL_FIELD_CREATE; + resize(xe, b.size(), csParam); + csParam.create = QUDA_ZERO_FIELD_CREATE; + resize(ye, b.size(), csParam); + init = true; + } + } + + cvector_ref CGNE::get_residual() + { + if (!init) errorQuda("No residual vector present"); + if (!param.return_residual) errorQuda("SolverParam::return_residual not enabled"); + // CG residual will match the CGNE residual (FIXME: but only with zero initial guess?) + return param.use_init_guess ? xe : cg->get_residual(); + } + + // CGNE: M Mdag y = b is solved; x = Mdag y is returned as solution. + void CGNE::operator()(cvector_ref &x, cvector_ref &b) + { + if (param.maxiter == 0 || param.Nsteps == 0) { + if (param.use_init_guess == QUDA_USE_INIT_GUESS_NO) blas::zero(x); + return; + } + + create(x, b); + + const int iter0 = param.iter; + auto b2 = param.compute_true_res ? 
blas::norm2(b) : vector(b.size(), 0.0); + + if (param.use_init_guess == QUDA_USE_INIT_GUESS_YES) { + // compute initial residual + mmdag.Expose()->M(xe, x); + + if (param.compute_true_res) { + bool is_zero = true; + for (auto i = 0u; i < b2.size(); i++) { + is_zero = is_zero || b2[i] == 0.0; + if (b2[i] == 0.0 && !is_zero) errorQuda("Mixture of zero and non-zero sources not supported"); + } + if (is_zero) b2 = blas::xmyNorm(b, xe); + } else { + blas::xpay(b, -1.0, xe); + } + + // compute solution to residual equation + cg->operator()(ye, xe); + + mmdag.Expose()->Mdag(xe, ye); + + // compute full solution + blas::xpy(xe, x); + } else { + cg->operator()(ye, b); + mmdag.Expose()->Mdag(x, ye); + } + + if (param.compute_true_res || (param.use_init_guess && param.return_residual)) { + // compute the true residual + mmdag.Expose()->M(xe, x); + blas::xpay(b, -1.0, xe); // xe now holds the residual + + vector r2(b2.size()); + if (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) { + auto hq = blas::HeavyQuarkResidualNorm(x, xe); + for (auto i = 0u; i < b.size(); i++) { + param.true_res_hq[i] = sqrt(hq[i].z); + r2[i] = hq[i].y; + } + } else { + r2 = blas::norm2(xe); + } + for (auto i = 0u; i < b.size(); i++) param.true_res[i] = sqrt(r2[i] / b2[i]); + PrintSummary("CGNE", param.iter - iter0, r2, b2, stopping(param.tol, b2, param.residual_type), param.tol_hq); + } + } + +} // namespace quda diff --git a/lib/inv_cgnr.cpp b/lib/inv_cgnr.cpp new file mode 100644 index 0000000000..9dc7b67e0a --- /dev/null +++ b/lib/inv_cgnr.cpp @@ -0,0 +1,90 @@ +#include "invert_quda.h" + +namespace quda +{ + + CGNR::CGNR(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, + const DiracMatrix &matEig, SolverParam &param) : + Solver(mat, mdagmSloppy, mdagmPrecon, mdagmEig, param), + mdagm(mat.Expose()), + mdagmSloppy(matSloppy.Expose()), + mdagmPrecon(matPrecon.Expose()), + mdagmEig(matEig.Expose()) + { + switch (param.inv_type) { + case QUDA_CGNR_INVERTER: cg = 
std::make_unique(mdagm, mdagmSloppy, mdagmPrecon, mdagmEig, param); break; + case QUDA_CA_CGNR_INVERTER: cg = std::make_unique(mdagm, mdagmSloppy, mdagmPrecon, mdagmEig, param); break; + case QUDA_CG3NR_INVERTER: cg = std::make_unique(mdagm, mdagmSloppy, mdagmPrecon, param); break; + default: errorQuda("Unexpected CG solver type %d", param.inv_type); + } + } + + void CGNR::create(cvector_ref &x, cvector_ref &b) + { + Solver::create(x, b); + if (!init || br.size() != b.size()) { + ColorSpinorParam csParam(b[0]); + csParam.create = QUDA_ZERO_FIELD_CREATE; + resize(br, b.size(), csParam); + init = true; + } + } + + cvector_ref CGNR::get_residual() + { + if (!init) errorQuda("No residual vector present"); + if (!param.return_residual) errorQuda("SolverParam::return_residual not enabled"); + return br; + } + + // CGNR: Mdag M x = Mdag b is solved. + void CGNR::operator()(cvector_ref &x, cvector_ref &b) + { + if (param.maxiter == 0 || param.Nsteps == 0) { + if (param.use_init_guess == QUDA_USE_INIT_GUESS_NO) blas::zero(x); + return; + } + + create(x, b); + + const int iter0 = param.iter; + vector b2(b.size(), 0.0); + if (param.compute_true_res) { + b2 = blas::norm2(b); + bool is_zero = true; + for (auto i = 0u; i < b2.size(); i++) { + is_zero = is_zero && b2[i] == 0.0; + if (b2[i] == 0.0 && !is_zero) errorQuda("Mixture of zero and non-zero sources not supported"); + } + if (is_zero) { // compute initial residual vector + mdagm.Expose()->M(br, x); + b2 = blas::norm2(br); + } + } + + mdagm.Expose()->Mdag(br, b); + cg->operator()(x, br); + + if (param.compute_true_res || param.return_residual) { + // compute the true residual + mdagm.Expose()->M(br, x); + blas::xpay(b, -1.0, br); // br now holds the residual + + if (param.compute_true_res) { + vector r2(b.size()); + if (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) { + auto hq = blas::HeavyQuarkResidualNorm(x, br); + for (auto i = 0u; i < b.size(); i++) { + param.true_res_hq[i] = sqrt(hq[i].z); + r2[i] = hq[i].y; + } + } 
else { + r2 = blas::norm2(br); + } + for (auto i = 0u; i < b.size(); i++) param.true_res[i] = sqrt(r2[i] / b2[i]); + PrintSummary("CGNR", param.iter - iter0, r2, b2, stopping(param.tol, b2, param.residual_type), param.tol_hq); + } + } + } + +} // namespace quda diff --git a/lib/solver.cpp b/lib/solver.cpp index f32cf8a9fd..da1128af09 100644 --- a/lib/solver.cpp +++ b/lib/solver.cpp @@ -81,11 +81,11 @@ namespace quda { break; case QUDA_CA_CGNE_INVERTER: report("CA-CGNE"); - solver = new CACGNE(mat, matSloppy, matPrecon, matEig, param); + solver = new CGNE(mat, matSloppy, matPrecon, matEig, param); break; case QUDA_CA_CGNR_INVERTER: report("CA-CGNR"); - solver = new CACGNR(mat, matSloppy, matPrecon, matEig, param); + solver = new CGNR(mat, matSloppy, matPrecon, matEig, param); break; case QUDA_CA_GCR_INVERTER: report("CA-GCR"); @@ -148,11 +148,11 @@ namespace quda { break; case QUDA_CG3NE_INVERTER: report("CG3NE"); - solver = new CG3NE(mat, matSloppy, matPrecon, param); + solver = new CGNE(mat, matSloppy, matPrecon, matEig, param); break; case QUDA_CG3NR_INVERTER: report("CG3NR"); - solver = new CG3NR(mat, matSloppy, matPrecon, param); + solver = new CGNR(mat, matSloppy, matPrecon, matEig, param); break; default: errorQuda("Invalid solver type %d", param.inv_type); From 0803993c0909ef7b2b022ab8b7eb92f56b38b71d Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Wed, 11 Sep 2024 15:31:12 -0700 Subject: [PATCH 070/103] Optimize HQ in CA-GCR --- lib/inv_ca_gcr.cpp | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/lib/inv_ca_gcr.cpp b/lib/inv_ca_gcr.cpp index 53d71bdaee..3a74acb00b 100644 --- a/lib/inv_ca_gcr.cpp +++ b/lib/inv_ca_gcr.cpp @@ -253,9 +253,11 @@ namespace quda const int maxResIncrease = param.max_res_increase; // check if we reached the limit of our tolerance const int maxResIncreaseTotal = param.max_res_increase_total; - std::vector heavy_quark_res(b.size()); // heavy quark residual - if (use_heavy_quark_res) - for (auto i = 
0u; i < b.size(); i++) heavy_quark_res[i] = sqrt(blas::HeavyQuarkResidualNorm(x, r)[i].z); + std::vector heavy_quark_res(b.size(), 0.0); // heavy quark residual + if (use_heavy_quark_res) { + auto hq = blas::HeavyQuarkResidualNorm(x, r); + for (auto i = 0u; i < b.size(); i++) heavy_quark_res[i] = sqrt(hq[i].z); + } int resIncrease = 0; int resIncreaseTotal = 0; @@ -330,8 +332,10 @@ namespace quda maxr_deflate = sqrt(r2[0]); } - if (use_heavy_quark_res) - for (auto i = 0u; i < b.size(); i++) heavy_quark_res[i] = sqrt(blas::HeavyQuarkResidualNorm(x, r)[i].z); + if (use_heavy_quark_res) { + auto hq = blas::HeavyQuarkResidualNorm(x, r); + for (auto i = 0u; i < b.size(); i++) heavy_quark_res[i] = sqrt(hq[i].z); + } // break-out check if we have reached the limit of the precision if (r2 > r2_old) { From e1589e5079184a74eb579509ac7da48aaf688308 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Wed, 11 Sep 2024 15:55:20 -0700 Subject: [PATCH 071/103] CA-CG is now MRHS --- include/invert_quda.h | 47 ++++----- lib/inv_ca_cg.cpp | 238 +++++++++++++++++++++++++----------------- 2 files changed, 167 insertions(+), 118 deletions(-) diff --git a/include/invert_quda.h b/include/invert_quda.h index 640c90c5f7..9946021052 100644 --- a/include/invert_quda.h +++ b/include/invert_quda.h @@ -1247,56 +1247,53 @@ namespace quda { bool lambda_init; QudaCABasis basis; - std::vector Q_AQandg; // Fused inner product matrix - std::vector Q_AS; // inner product matrix - std::vector alpha; // QAQ^{-1} g - std::vector beta; // QAQ^{-1} QpolyS + std::vector> Q_AQandg; // Fused inner product matrix + std::vector> Q_AS; // inner product matrix + std::vector> alpha; // QAQ^{-1} g + std::vector> beta; // QAQ^{-1} QpolyS - ColorSpinorField r; + std::vector r; - std::vector S; // residual vectors - std::vector AS; // mat * residual vectors. Can be replaced by a single temporary. 
- std::vector Q; // CG direction vectors - std::vector Qtmp; // CG direction vectors for pointer swap - std::vector AQ; // mat * CG direction vectors. - // it's possible to avoid carrying these - // around, but there's a stability penalty, - // and computing QAQ becomes a pain (though - // it does let you fuse the reductions...) + std::vector> S; // residual vectors + std::vector> AS; // mat * residual vectors. Can be replaced by a single temporary. + std::vector> Q; // CG direction vectors + std::vector> Qtmp; // CG direction vectors for pointer swap + std::vector> AQ; // mat * CG direction vectors. + // it's possible to avoid carrying these + // around, but there's a stability penalty, + // and computing QAQ becomes a pain (though + // it does let you fuse the reductions...) /** @brief Initiate the fields needed by the solver @param[in] x Solution vector @param[in] b Source vector */ - void create(ColorSpinorField &x, const ColorSpinorField &b); + void create(cvector_ref &x, cvector_ref &b); /** @brief Compute the alpha coefficients + @param[in] b batch number */ - void compute_alpha(); + void compute_alpha(int b); /** @brief Compute the beta coefficients + @param[in] b batch number */ - void compute_beta(); + void compute_beta(int b); /** - @ brief Check if it's time for a reliable update + @brief Check if it's time for a reliable update */ - int reliable(double &rNorm, double &maxrr, int &rUpdate, const double &r2, const double &delta); + int reliable(double &rNorm, double &maxrr, int &rUpdate, const double &r2, const double &delta); public: CACG(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, const DiracMatrix &matEig, SolverParam &param); virtual ~CACG(); - void operator()(cvector_ref &out, cvector_ref &in) override - { - for (auto i = 0u; i < in.size(); i++) operator()(out[i], in[i]); - } - - void operator()(ColorSpinorField &out, const ColorSpinorField &in); + void operator()(cvector_ref &out, cvector_ref &in) override; /** 
@return Return the residual vector from the prior solve diff --git a/lib/inv_ca_cg.cpp b/lib/inv_ca_cg.cpp index c50f18dc31..9afd8c37dd 100644 --- a/lib/inv_ca_cg.cpp +++ b/lib/inv_ca_cg.cpp @@ -27,21 +27,27 @@ namespace quda if (!param.is_preconditioner) getProfile().TPSTOP(QUDA_PROFILE_FREE); } - void CACG::create(ColorSpinorField &x, const ColorSpinorField &b) + void CACG::create(cvector_ref &x, cvector_ref &b) { Solver::create(x, b); - if (!init) { + if (!init || r.size() != b.size()) { if (!param.is_preconditioner) getProfile().TPSTART(QUDA_PROFILE_INIT); - Q_AQandg.resize(param.Nkrylov * (param.Nkrylov + 1)); - Q_AS.resize(param.Nkrylov * param.Nkrylov); - alpha.resize(param.Nkrylov); - beta.resize(param.Nkrylov * param.Nkrylov); + Q_AQandg.resize(b.size()); + Q_AS.resize(b.size()); + alpha.resize(b.size()); + beta.resize(b.size()); + for (auto i = 0u; i < b.size(); i++) { + Q_AQandg[i].resize(param.Nkrylov * (param.Nkrylov + 1)); + Q_AS[i].resize(param.Nkrylov * param.Nkrylov); + alpha[i].resize(param.Nkrylov); + beta[i].resize(param.Nkrylov * param.Nkrylov); + } - ColorSpinorParam csParam(b); + ColorSpinorParam csParam(b[0]); csParam.create = QUDA_NULL_FIELD_CREATE; - if (mixed()) r = ColorSpinorField(csParam); + if (mixed()) resize(r, b.size(), csParam); // now allocate sloppy fields csParam.setPrecision(param.precision_sloppy); @@ -52,15 +58,18 @@ namespace quda Qtmp.resize(param.Nkrylov); // only used as an intermediate for pointer swaps S.resize(param.Nkrylov); for (int i = 0; i < param.Nkrylov; i++) { - AS[i] = ColorSpinorField(csParam); - Q[i] = ColorSpinorField(csParam); - AQ[i] = ColorSpinorField(csParam); - Qtmp[i] = ColorSpinorField(csParam); + resize(AS[i], b.size(), csParam); + resize(Q[i], b.size(), csParam); + resize(AQ[i], b.size(), csParam); + resize(Qtmp[i], b.size(), csParam); // in the power basis we can alias AS[k] to S[k+1] - S[i] = (basis == QUDA_POWER_BASIS && i > 0) ? 
AS[i - 1] : ColorSpinorField(csParam); + if (basis == QUDA_POWER_BASIS && i > 0) + create_alias(S[i], AS[i - 1]); + else + resize(S[i], b.size(), csParam); } - if (!mixed()) r = S[0].create_alias(csParam); + if (!mixed()) create_alias(r, S[0]); if (!param.is_preconditioner) getProfile().TPSTOP(QUDA_PROFILE_INIT); @@ -88,7 +97,7 @@ namespace quda // psi = svd.solve(phi); } - void CACG::compute_alpha() + void CACG::compute_alpha(int b) { if (!param.is_preconditioner) { getProfile().TPSTOP(QUDA_PROFILE_COMPUTE); @@ -98,31 +107,30 @@ namespace quda const int N = Q.size(); switch (N) { #if 0 // since CA-CG is not used anywhere at the moment, no point paying for this compilation cost - case 1: compute_alpha_N<1>(Q_AQandg, alpha); break; - case 2: compute_alpha_N<2>(Q_AQandg, alpha); break; - case 3: compute_alpha_N<3>(Q_AQandg, alpha); break; - case 4: compute_alpha_N<4>(Q_AQandg, alpha); break; - case 5: compute_alpha_N<5>(Q_AQandg, alpha); break; - case 6: compute_alpha_N<6>(Q_AQandg, alpha); break; - case 7: compute_alpha_N<7>(Q_AQandg, alpha); break; - case 8: compute_alpha_N<8>(Q_AQandg, alpha); break; - case 9: compute_alpha_N<9>(Q_AQandg, alpha); break; - case 10: compute_alpha_N<10>(Q_AQandg, alpha); break; - case 11: compute_alpha_N<11>(Q_AQandg, alpha); break; - case 12: compute_alpha_N<12>(Q_AQandg, alpha); break; + case 1: compute_alpha_N<1>(Q_AQandg[b], alpha[b]); break; + case 2: compute_alpha_N<2>(Q_AQandg[b], alpha[b]); break; + case 3: compute_alpha_N<3>(Q_AQandg[b], alpha[b]); break; + case 4: compute_alpha_N<4>(Q_AQandg[b], alpha[b]); break; + case 5: compute_alpha_N<5>(Q_AQandg[b], alpha[b]); break; + case 6: compute_alpha_N<6>(Q_AQandg[b], alpha[b]); break; + case 7: compute_alpha_N<7>(Q_AQandg[b], alpha[b]); break; + case 8: compute_alpha_N<8>(Q_AQandg[b], alpha[b]); break; + case 9: compute_alpha_N<9>(Q_AQandg[b], alpha[b]); break; + case 10: compute_alpha_N<10>(Q_AQandg[b], alpha[b]); break; + case 11: compute_alpha_N<11>(Q_AQandg[b], alpha[b]); 
break; + case 12: compute_alpha_N<12>(Q_AQandg[b], alpha[b]); break; #endif default: // failsafe typedef Matrix matrix; typedef Matrix vector; - const int N = Q.size(); matrix matQ_AQ(N, N); vector vecg(N); for (int i = 0; i < N; i++) { - vecg(i) = Q_AQandg[i * (N + 1) + N]; - for (int j = 0; j < N; j++) { matQ_AQ(i, j) = Q_AQandg[i * (N + 1) + j]; } + vecg(i) = Q_AQandg[b][i * (N + 1) + N]; + for (int j = 0; j < N; j++) { matQ_AQ(i, j) = Q_AQandg[b][i * (N + 1) + j]; } } - Map vecalpha(alpha.data(), N); + Map vecalpha(alpha[b].data(), N); vecalpha = matQ_AQ.fullPivLu().solve(vecg); @@ -156,7 +164,7 @@ namespace quda // psi = svd.solve(phi); } - void CACG::compute_beta() + void CACG::compute_beta(int b) { if (!param.is_preconditioner) { getProfile().TPSTOP(QUDA_PROFILE_COMPUTE); @@ -166,28 +174,27 @@ namespace quda const int N = Q.size(); switch (N) { #if 0 // since CA-CG is not used anywhere at the moment, no point paying for this compilation cost - case 1: compute_beta_N<1>(Q_AQandg, Q_AS, beta); break; - case 2: compute_beta_N<2>(Q_AQandg, Q_AS, beta); break; - case 3: compute_beta_N<3>(Q_AQandg, Q_AS, beta); break; - case 4: compute_beta_N<4>(Q_AQandg, Q_AS, beta); break; - case 5: compute_beta_N<5>(Q_AQandg, Q_AS, beta); break; - case 6: compute_beta_N<6>(Q_AQandg, Q_AS, beta); break; - case 7: compute_beta_N<7>(Q_AQandg, Q_AS, beta); break; - case 8: compute_beta_N<8>(Q_AQandg, Q_AS, beta); break; - case 9: compute_beta_N<9>(Q_AQandg, Q_AS, beta); break; - case 10: compute_beta_N<10>(Q_AQandg, Q_AS, beta); break; - case 11: compute_beta_N<11>(Q_AQandg, Q_AS, beta); break; - case 12: compute_beta_N<12>(Q_AQandg, Q_AS, beta); break; + case 1: compute_beta_N<1>(Q_AQandg[b], Q_AS[b], beta[b]); break; + case 2: compute_beta_N<2>(Q_AQandg[b], Q_AS[b], beta[b]); break; + case 3: compute_beta_N<3>(Q_AQandg[b], Q_AS[b], beta[b]); break; + case 4: compute_beta_N<4>(Q_AQandg[b], Q_AS[b], beta[b]); break; + case 5: compute_beta_N<5>(Q_AQandg[b], Q_AS[b], beta[b]); break; 
+ case 6: compute_beta_N<6>(Q_AQandg[b], Q_AS[b], beta[b]); break; + case 7: compute_beta_N<7>(Q_AQandg[b], Q_AS[b], beta[b]); break; + case 8: compute_beta_N<8>(Q_AQandg[b], Q_AS[b], beta[b]); break; + case 9: compute_beta_N<9>(Q_AQandg[b], Q_AS[b], beta[b]); break; + case 10: compute_beta_N<10>(Q_AQandg[b], Q_AS[b], beta[b]); break; + case 11: compute_beta_N<11>(Q_AQandg[b], Q_AS[b], beta[b]); break; + case 12: compute_beta_N<12>(Q_AQandg[b], Q_AS[b], beta[b]); break; #endif default: // failsafe typedef Matrix matrix; - const int N = Q.size(); matrix matQ_AQ(N, N); for (int i = 0; i < N; i++) { - for (int j = 0; j < N; j++) { matQ_AQ(i, j) = Q_AQandg[i * (N + 1) + j]; } + for (int j = 0; j < N; j++) { matQ_AQ(i, j) = Q_AQandg[b][i * (N + 1) + j]; } } - Map matQ_AS(Q_AS.data(), N, N), matbeta(beta.data(), N, N); + Map matQ_AS(Q_AS[b].data(), N, N), matbeta(beta[b].data(), N, N); matQ_AQ = -matQ_AQ; matbeta = matQ_AQ.fullPivLu().solve(matQ_AS); @@ -244,7 +251,7 @@ namespace quda 2. Steepest descent minmization of the residual in this basis 3. Update solution and residual vectors */ - void CACG::operator()(ColorSpinorField &x, const ColorSpinorField &b) + void CACG::operator()(cvector_ref &x, cvector_ref &b) { if (param.is_preconditioner) commGlobalReductionPush(param.global_reduction); @@ -261,12 +268,12 @@ namespace quda // compute b2, but only if we need to bool fixed_iteration = param.sloppy_converge && n_krylov == param.maxiter && !param.compute_true_res; - double b2 = !fixed_iteration ? blas::norm2(b) : 1.0; - double r2 = 0.0; // if zero source then we will exit immediately doing no work + auto b2 = !fixed_iteration ? blas::norm2(b) : vector(b.size(), 1.0); + vector r2(b.size(), 0.0); // if zero source then we will exit immediately doing no work if (param.deflate) { // Construct the eigensolver and deflation space. - constructDeflationSpace(b, matEig); + constructDeflationSpace(b[0], matEig); if (deflate_compute) { // compute the deflation space. 
if (!param.is_preconditioner) getProfile().TPSTOP(QUDA_PROFILE_PREAMBLE); @@ -296,6 +303,12 @@ namespace quda blas::zero(x); } + // Check to see that we're not trying to invert on a zero-field source + if (is_zero_src(x, b, b2)) { + if (!param.is_preconditioner) getProfile().TPSTOP(QUDA_PROFILE_PREAMBLE); + return; + } + if (param.deflate && param.maxiter > 1) { // Deflate and add solution to accumulator eig_solve->deflate(x, r, evecs, evals, true); @@ -321,7 +334,7 @@ namespace quda // Perform 100 power iterations, normalizing every 10 mat-vecs, using r as an initial seed // and Q[0]/AQ[0] as temporaries for the power iterations - lambda_max = 1.1 * Solver::performPowerIterations(matSloppy, r, Q[0], AQ[0], 100, 10); + lambda_max = 1.1 * Solver::performPowerIterations(matSloppy, r[0], Q[0][0], AQ[0][0], 100, 10); logQuda(QUDA_SUMMARIZE, "CA-CG Approximate lambda max = 1.1 x %e\n", lambda_max / 1.1); lambda_init = true; @@ -332,20 +345,8 @@ namespace quda } } - // Check to see that we're not trying to invert on a zero-field source - if (b2 == 0) { - if (param.compute_null_vector == QUDA_COMPUTE_NULL_VECTOR_NO) { - warningQuda("inverting on zero-field source\n"); - x = b; - param.true_res = 0.0; - param.true_res_hq = 0.0; - return; - } else { - b2 = r2; - } - } - - double stop = !fixed_iteration ? stopping(param.tol, b2, param.residual_type) : 0.0; // stopping condition of solver + auto stop = !fixed_iteration ? stopping(param.tol, b2, param.residual_type) : + vector(b2.size(), 0.0); // stopping condition of solver const bool use_heavy_quark_res = (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) ? 
true : false; @@ -355,8 +356,11 @@ namespace quda const int maxResIncrease = param.max_res_increase; // check if we reached the limit of our tolerance const int maxResIncreaseTotal = param.max_res_increase_total; - double heavy_quark_res = 0.0; // heavy quark residual - if (use_heavy_quark_res) heavy_quark_res = sqrt(blas::HeavyQuarkResidualNorm(x, r).z); + vector heavy_quark_res(b.size(), 0.0); // heavy quark residual + if (use_heavy_quark_res) { + auto hq = blas::HeavyQuarkResidualNorm(x, r); + for (auto i = 0u; i < b.size(); i++) heavy_quark_res[i] = sqrt(hq[i].z); + } int resIncrease = 0; int resIncreaseTotal = 0; @@ -366,13 +370,13 @@ namespace quda getProfile().TPSTART(QUDA_PROFILE_COMPUTE); } int total_iter = 0; - double r2_old = r2; + auto r2_old = r2; bool l2_converge = false; // Various variables related to reliable updates. int rUpdate = 0; // count reliable updates. double delta = param.delta; // delta for reliable updates. - double rNorm = sqrt(r2); // The current residual norm. + double rNorm = sqrt(r2[0]); // The current residual norm. double maxrr = rNorm; // The maximum residual norm since the last reliable update. double maxr_deflate = rNorm; // The maximum residual since the last deflation @@ -380,6 +384,13 @@ namespace quda double m_map = 2. / (lambda_max - lambda_min); double b_map = -(lambda_max + lambda_min) / (lambda_max - lambda_min); + auto get_i = [](std::vector> &p, int i) { + vector_ref p_i; + p_i.reserve(p.size()); + for (auto &pi : p) p_i.push_back(pi[i]); + return p_i; + }; + blas::copy(S[0], r); // no op if uni-precision PrintStats("CA-CG", total_iter, r2, b2, heavy_quark_res); @@ -401,14 +412,28 @@ namespace quda // Compute the beta coefficients for updating Q, AQ // 1. compute matrix Q_AS = -Q^\dagger AS // 2. 
Solve Q_AQ beta = Q_AS - blas::block::reDotProduct(Q_AS, AQ, S); - compute_beta(); + for (auto i = 0u; i < b.size(); i++) { + auto AQi = get_i(AQ, i); + auto Si = get_i(S, i); + blas::block::reDotProduct(Q_AS[i], AQi, Si); + compute_beta(i); + } // update direction vectors - blas::block::axpyz(beta, Q, S, Qtmp); + for (auto i = 0u; i < b.size(); i++) { + auto Qi = get_i(Q, i); + auto Si = get_i(S, i); + auto Qtmpi = get_i(Qtmp, i); + blas::block::axpyz(beta[i], Qi, Si, Qtmpi); + } for (int i = 0; i < n_krylov; i++) std::swap(Q[i], Qtmp[i]); - blas::block::axpyz(beta, AQ, AS, Qtmp); + for (auto i = 0u; i < b.size(); i++) { + auto AQi = get_i(AQ, i); + auto ASi = get_i(AS, i); + auto Qtmpi = get_i(Qtmp, i); + blas::block::axpyz(beta[i], AQi, ASi, Qtmpi); + } for (int i = 0; i < n_krylov; i++) std::swap(AQ[i], Qtmp[i]); } @@ -416,19 +441,34 @@ namespace quda // 1. Compute Q_AQ = Q^\dagger AQ and g = Q^dagger r = Q^dagger S[0] // 2. Solve Q_AQ alpha = g { - blas::block::reDotProduct(Q_AQandg, Q, {AQ, S[0]}); - compute_alpha(); + for (auto i = 0u; i < b.size(); i++) { + auto Qi = get_i(Q, i); + auto AQi = get_i(AQ, i); + auto Si = get_i(S, i); + blas::block::reDotProduct(Q_AQandg[i], Qi, {AQi, Si[0]}); + compute_alpha(i); + } } // update the solution vector - blas::block::axpy(alpha, Q, x); + for (auto i = 0u; i < b.size(); i++) { + auto Qi = get_i(Q, i); + blas::block::axpy(alpha[i], Qi, x[i]); + } - for (int i = 0; i < param.Nkrylov; i++) { alpha[i] = -alpha[i]; } + for (auto i = 0u; i < b.size(); i++) + for (auto j = 0; j < n_krylov; j++) { alpha[i][j] = -alpha[i][j]; } // Can we fuse these? We don't need this reduce in all cases... - blas::block::axpy(alpha, AQ, S[0]); + for (auto i = 0u; i < b.size(); i++) { + auto AQi = get_i(AQ, i); + auto Si = get_i(S, i); + blas::block::axpy(alpha[i], AQi, Si[0]); + } // if (getVerbosity() >= QUDA_VERBOSE) r2 = blas::norm2(S[0]); - /*else*/ r2 = Q_AQandg[param.Nkrylov]; // actually the old r2... 
so we do one more iter than needed... + /*else*/ + // actually the old r2... so we do one more iter than needed... + for (auto i = 0u; i < r2.size(); i++) r2[i] = Q_AQandg[i][param.Nkrylov]; } else { // fixed iterations // On the first pass, Q = S; AQ = AQ. We can just skip that. @@ -438,14 +478,21 @@ namespace quda // We do compute the alpha coefficients: this is the same code as above // 1. Compute "Q_AQ" = S^\dagger AS and g = S^dagger r = S^dagger S[0] // 2. Solve "Q_AQ" alpha = g - blas::block::reDotProduct(Q_AQandg, S, {AS, S[0]}); - compute_alpha(); + for (auto i = 0u; i < b.size(); i++) { + auto Si = get_i(S, i); + auto ASi = get_i(AS, i); + blas::block::reDotProduct(Q_AQandg[i], Si, {ASi, Si[0]}); + compute_alpha(i); + } // update the solution vector - blas::block::axpy(alpha, S, x); + for (auto i = 0u; i < b.size(); i++) { + auto Si = get_i(S, i); + blas::block::axpy(alpha[i], Si, x[i]); + } // no need to update AS - r2 = Q_AQandg[param.Nkrylov]; // actually the old r2... so we do one more iter than needed... + for (auto i = 0u; i < r2.size(); i++) r2[i] = Q_AQandg[i][param.Nkrylov]; } // NOTE: Because we always carry around the residual from an iteration before, we @@ -458,13 +505,13 @@ namespace quda // update since n_krylov or maxiter reached, converged or reliable update required // note that the heavy quark residual will by definition only be checked every n_krylov steps // Note: this won't reliable update when the norm _increases_. 
- if (total_iter >= param.maxiter || (r2 < stop && !l2_converge) || reliable(rNorm, maxrr, rUpdate, r2, delta)) { + if (total_iter >= param.maxiter || (r2 < stop && !l2_converge) || reliable(rNorm, maxrr, rUpdate, r2[0], delta)) { if ((r2 < stop || total_iter >= param.maxiter) && param.sloppy_converge) break; mat(r, x); r2 = blas::xmyNorm(b, r); - if (param.deflate && sqrt(r2) < maxr_deflate * param.tol_restart) { + if (param.deflate && sqrt(r2[0]) < maxr_deflate * param.tol_restart) { // Deflate and add solution to accumulator eig_solve->deflate(x, r, evecs, evals, true); @@ -472,12 +519,15 @@ namespace quda mat(r, x); r2 = blas::xmyNorm(b, r); - maxr_deflate = sqrt(r2); + maxr_deflate = sqrt(r2[0]); } blas::copy(S[0], r); - if (use_heavy_quark_res) heavy_quark_res = sqrt(blas::HeavyQuarkResidualNorm(x, r).z); + if (use_heavy_quark_res) { + auto hq = blas::HeavyQuarkResidualNorm(x, r); + for (auto i = 0u; i < b.size(); i++) heavy_quark_res[i] = sqrt(hq[i].z); + } // break-out check if we have reached the limit of the precision if (r2 > r2_old) { @@ -485,7 +535,7 @@ namespace quda resIncreaseTotal++; warningQuda( "CA-CG: new reliable residual norm %e is greater than previous reliable residual norm %e (total #inc %i)", - sqrt(r2), sqrt(r2_old), resIncreaseTotal); + sqrt(r2[0]), sqrt(r2_old[0]), resIncreaseTotal); if (resIncrease > maxResIncrease or resIncreaseTotal > maxResIncreaseTotal) { warningQuda("CA-CG: solver exiting due to too many true residual norm increases"); break; @@ -507,10 +557,12 @@ namespace quda if (param.compute_true_res) { // Calculate the true residual mat(r, x); - double true_res = blas::xmyNorm(b, r); - param.true_res = sqrt(true_res / b2); - param.true_res_hq - = (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) ?
sqrt(blas::HeavyQuarkResidualNorm(x, r).z) : 0.0; + auto true_res = blas::xmyNorm(b, r); + auto hq = blas::HeavyQuarkResidualNorm(x, r); + for (auto i = 0u; i < b.size(); i++) { + param.true_res[i] = sqrt(true_res[i] / b2[i]); + param.true_res_hq[i] = sqrt(hq[i].z); + } } if (!param.is_preconditioner) { From a3186a0fd88d400e519cf8f875355ab201eeddc4 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Thu, 12 Sep 2024 12:50:01 -0700 Subject: [PATCH 072/103] BiCGStab is now MRHS --- include/invert_quda.h | 29 ++--- lib/inv_bicgstab_quda.cpp | 232 +++++++++++++++++++------------------- 2 files changed, 127 insertions(+), 134 deletions(-) diff --git a/include/invert_quda.h b/include/invert_quda.h index 9946021052..5ae9e5aa0c 100644 --- a/include/invert_quda.h +++ b/include/invert_quda.h @@ -897,7 +897,7 @@ namespace quda { public: CG3(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, SolverParam &param); - virtual void operator()(cvector_ref &out, cvector_ref &in) override; + void operator()(cvector_ref &out, cvector_ref &in) override; /** @return Return the residual vector from the prior solve @@ -981,14 +981,14 @@ namespace quda { private: const DiracMdagM matMdagM; // used by the eigensolver - ColorSpinorField y; // Full precision solution accumulator - ColorSpinorField r; // Full precision residual vector - ColorSpinorField p; // Sloppy precision search direction - ColorSpinorField v; // Sloppy precision A * p - ColorSpinorField t; // Sloppy precision vector used for minres step - ColorSpinorField r0; // Bi-orthogonalization vector - ColorSpinorField r_sloppy; // Slopy precision residual vector - ColorSpinorField x_sloppy; // Sloppy solution accumulator vector + std::vector y; // Full precision solution accumulator + std::vector r; // Full precision residual vector + std::vector p; // Sloppy precision search direction + std::vector v; // Sloppy precision A * p + std::vector t; // Sloppy precision vector used for minres step +
std::vector r0; // Bi-orthogonalization vector + std::vector r_sloppy; // Sloppy precision residual vector + std::vector x_sloppy; // Sloppy solution accumulator vector bool init = false; /** @@ -996,19 +996,14 @@ namespace quda { @param[in] x Solution vector @param[in] b Source vector */ - void create(ColorSpinorField &x, const ColorSpinorField &b); + void create(cvector_ref &x, cvector_ref &b); public: BiCGstab(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, const DiracMatrix &matEig, SolverParam &param); virtual ~BiCGstab(); - void operator()(cvector_ref &out, cvector_ref &in) override - { - for (auto i = 0u; i < in.size(); i++) operator()(out[i], in[i]); - } - - void operator()(ColorSpinorField &out, const ColorSpinorField &in); + void operator()(cvector_ref &out, cvector_ref &in) override; /** @return Return the residual vector from the prior solve @@ -1188,8 +1183,6 @@ namespace quda { void operator()(cvector_ref &out, cvector_ref &in) override; - void operator()(ColorSpinorField &out, const ColorSpinorField &in); - /** @return Return the residual vector from the prior solve */ diff --git a/lib/inv_bicgstab_quda.cpp b/lib/inv_bicgstab_quda.cpp index 1c2e8187b5..73e3f9868e 100644 --- a/lib/inv_bicgstab_quda.cpp +++ b/lib/inv_bicgstab_quda.cpp @@ -19,24 +19,45 @@ namespace quda { BiCGstab::~BiCGstab() { destroyDeflationSpace(); } - void BiCGstab::create(ColorSpinorField &x, const ColorSpinorField &b) + void BiCGstab::create(cvector_ref &x, cvector_ref &b) { Solver::create(x, b); - if (!init) { + if (!init || r.size() != b.size()) { getProfile().TPSTART(QUDA_PROFILE_INIT); - ColorSpinorParam csParam(x); + ColorSpinorParam csParam(x[0]); csParam.create = QUDA_ZERO_FIELD_CREATE; - y = ColorSpinorField(csParam); - r = ColorSpinorField(csParam); + resize(y, b.size(), csParam); + resize(r, b.size(), csParam); csParam.setPrecision(param.precision_sloppy); - p = ColorSpinorField(csParam); - v = ColorSpinorField(csParam); - t =
ColorSpinorField(csParam); + resize(p, b.size(), csParam); + resize(v, b.size(), csParam); + resize(t, b.size(), csParam); + + if (param.precision_sloppy == x.Precision()) { + create_alias(r_sloppy, r); + if (param.compute_null_vector == QUDA_COMPUTE_NULL_VECTOR_NO) { + create_alias(r0, b); + } else { + csParam.create = QUDA_NULL_FIELD_CREATE; + resize(r0, b.size(), csParam); + blas::copy(r0, r); + } + } else { + csParam.create = QUDA_NULL_FIELD_CREATE; + resize(r_sloppy, b.size(), csParam); + resize(r0, b.size(), csParam); + } + + if (param.precision_sloppy == x.Precision() || !param.use_sloppy_partial_accumulator) { + create_alias(x_sloppy, x); + } else { + resize(x_sloppy, b.size(), csParam); + } - getProfile().TPSTOP(QUDA_PROFILE_INIT); init = true; + getProfile().TPSTOP(QUDA_PROFILE_INIT); } // init } @@ -56,27 +77,31 @@ namespace quda { //int updateX = (rNorm < delta*r0Norm && r0Norm <= maxrx) ? 1 : 0 int updateR = (rNorm < delta*maxrr) ? 1 : 0; - //printf("reliable %d %e %e %e %e\n", updateR, rNorm, maxrx, maxrr, r2); - return updateR; } - void BiCGstab::operator()(ColorSpinorField &x, const ColorSpinorField &b) + void BiCGstab::operator()(cvector_ref &x, cvector_ref &b) { create(x, b); getProfile().TPSTART(QUDA_PROFILE_INIT); - double b2 = blas::norm2(b); // norm sq of source - double r2 = 0.0; // norm sq of residual + auto b2 = blas::norm2(b); // norm sq of source + vector r2(b.size(), 0.0); // norm sq of residual + + // Check to see that we're not trying to invert on a zero-field source + if (is_zero_src(x, b, b2)) { + getProfile().TPSTOP(QUDA_PROFILE_INIT); + return; + } if (param.deflate) { // Construct the eigensolver and deflation space if requested. 
if (param.eig_param.eig_type == QUDA_EIG_TR_LANCZOS || param.eig_param.eig_type == QUDA_EIG_BLK_TR_LANCZOS) { - constructDeflationSpace(b, matMdagM); + constructDeflationSpace(b[0], matMdagM); } else { // Use Arnoldi to inspect the space only and turn off deflation - constructDeflationSpace(b, mat); + constructDeflationSpace(b[0], mat); param.deflate = false; } if (deflate_compute) { @@ -101,7 +126,7 @@ namespace quda { if (param.use_init_guess == QUDA_USE_INIT_GUESS_YES) { mat(r, x); r2 = blas::xmyNorm(b, r); - blas::copy(y, x); + for (auto i = 0u; i < x.size(); i++) std::swap(y[i], x[i]); } else { blas::copy(r, b); r2 = b2; @@ -117,62 +142,23 @@ namespace quda { r2 = blas::xmyNorm(b, r); } - // Check to see that we're not trying to invert on a zero-field source - if (b2 == 0) { - if (param.compute_null_vector == QUDA_COMPUTE_NULL_VECTOR_NO) { - warningQuda("inverting on zero-field source"); - x = b; - param.true_res = 0.0; - param.true_res_hq = 0.0; - getProfile().TPSTOP(QUDA_PROFILE_INIT); - return; - } else if (param.use_init_guess == QUDA_USE_INIT_GUESS_YES) { - b2 = r2; - } else { - errorQuda("Null vector computing requires non-zero guess!"); - } - } - - // set field aliasing according to whether we are doing mixed precision or not - if (param.precision_sloppy == x.Precision()) { - r_sloppy = r.create_alias(); - - if (param.compute_null_vector == QUDA_COMPUTE_NULL_VECTOR_NO) { - r0 = const_cast(b).create_alias(); - } else { - ColorSpinorParam csParam(r); - csParam.create = QUDA_NULL_FIELD_CREATE; - r0 = ColorSpinorField(csParam); - blas::copy(r0, r); - } - } else { - ColorSpinorParam csParam(x); - csParam.setPrecision(param.precision_sloppy); - csParam.create = QUDA_NULL_FIELD_CREATE; - r_sloppy = ColorSpinorField(csParam); + if (param.precision != param.precision_sloppy) { blas::copy(r_sloppy, r); - r0 = ColorSpinorField(csParam); - blas::copy(r0, r); - } - - if (param.precision_sloppy == x.Precision() || !param.use_sloppy_partial_accumulator) { - x_sloppy = 
x.create_alias(); - blas::zero(x_sloppy); - } else { - ColorSpinorParam csParam(x); - csParam.create = QUDA_ZERO_FIELD_CREATE; - csParam.setPrecision(param.precision_sloppy); - x_sloppy = ColorSpinorField(csParam); + blas::copy(r0, param.compute_null_vector == QUDA_COMPUTE_NULL_VECTOR_NO ? b : r); } getProfile().TPSTOP(QUDA_PROFILE_INIT); getProfile().TPSTART(QUDA_PROFILE_PREAMBLE); - double stop = stopping(param.tol, b2, param.residual_type); // stopping condition of solver + auto stop = stopping(param.tol, b2, param.residual_type); // stopping condition of solver const bool use_heavy_quark_res = (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) ? true : false; - double heavy_quark_res = use_heavy_quark_res ? sqrt(blas::HeavyQuarkResidualNorm(x,r).z) : 0.0; + vector heavy_quark_res(b.size(), 0.0); + if (use_heavy_quark_res) { + auto hq = blas::HeavyQuarkResidualNorm(x, r); + for (auto i = 0u; i < b.size(); i++) heavy_quark_res[i] = sqrt(hq[i].z); + } const int heavy_quark_check = param.heavy_quark_check; // how often to check the heavy quark residual double delta = param.delta; @@ -180,16 +166,15 @@ namespace quda { int k = 0; int rUpdate = 0; - Complex rho(1.0, 0.0); - Complex rho0 = rho; - Complex alpha(1.0, 0.0); - Complex omega(1.0, 0.0); - Complex beta; + vector rho(b.size(), {1.0, 0.0}); + vector rho0 = rho; + vector alpha(b.size(), {1.0, 0.0}); + vector omega(b.size(), {1.0, 0.0}); + vector beta(b.size()); - double3 rho_r2; - double3 omega_t2; + vector rho_r2(b.size()); - double rNorm = sqrt(r2); + double rNorm = sqrt(r2[0]); //double r0Norm = rNorm; double maxrr = rNorm; double maxrx = rNorm; @@ -204,9 +189,6 @@ namespace quda { bool converged = convergence(r2, heavy_quark_res, stop, param.tol_hq); - logQuda(QUDA_DEBUG_VERBOSE, "BiCGstab debug: x2=%e, r2=%e, v2=%e, p2=%e, r0=%e, t2=%e\n", blas::norm2(x), - blas::norm2(r_sloppy), blas::norm2(v), blas::norm2(p), blas::norm2(r0), blas::norm2(t)); - // track if we just performed an exact recalculation of y, r, 
r2 bool just_updated = false; @@ -215,15 +197,19 @@ namespace quda { matSloppy(v, p); - Complex r0v; + vector r0v; if (param.pipeline) { r0v = blas::cDotProduct(r0, v); if (k > 0) rho = blas::cDotProduct(r0, r); } else { r0v = blas::cDotProduct(r0, v); } - if (abs(rho) == 0.0) alpha = 0.0; - else alpha = rho / r0v; + for (auto i = 0u; i < b.size(); i++) { + if (abs(rho[i]) == 0.0) + alpha[i] = 0.0; + else + alpha[i] = rho[i] / r0v[i]; + } // r -= alpha*v blas::caxpy(-alpha, v, r_sloppy); @@ -233,55 +219,63 @@ namespace quda { int updateR = 0; if (param.pipeline) { // omega = (t, r) / (t, t) - omega_t2 = blas::cDotProductNormA(t, r_sloppy); - Complex tr = Complex(omega_t2.x, omega_t2.y); - double t2 = omega_t2.z; - omega = tr / t2; - double s2 = blas::norm2(r_sloppy); - Complex r0t = blas::cDotProduct(r0, t); - beta = -r0t / r0v; - r2 = s2 - real(omega * conj(tr)); + auto omega_t2_s2 = blas::cDotProductNormAB(t, r_sloppy); + auto r0t = blas::cDotProduct(r0, t); + + for (auto i = 0u; i < b.size(); i++) { + omega[i] = Complex {omega_t2_s2[i].x, omega_t2_s2[i].y} / omega_t2_s2[i].z; + beta[i] = -r0t[i] / r0v[i]; + r2[i] = omega_t2_s2[i].w - real(omega[i] * conj(Complex {omega_t2_s2[i].x, omega_t2_s2[i].y})); + } // now we can work out if we need to do a reliable update - updateR = reliable(rNorm, maxrx, maxrr, r2, delta); + updateR = reliable(rNorm, maxrx, maxrr, r2[0], delta); } else { // omega = (t, r) / (t, t) - omega_t2 = blas::cDotProductNormA(t, r_sloppy); - omega = Complex(omega_t2.x / omega_t2.z, omega_t2.y / omega_t2.z); + auto omega_t2 = blas::cDotProductNormA(t, r_sloppy); + for (auto i = 0u; i < b.size(); i++) + omega[i] = Complex(omega_t2[i].x / omega_t2[i].z, omega_t2[i].y / omega_t2[i].z); } if (param.pipeline && !updateR) { // x += alpha*p + omega*r, r -= omega*t, p = r - beta*omega*v + beta*p blas::caxpbypzYmbw(alpha, p, omega, r_sloppy, x_sloppy, t); - blas::cxpaypbz(r_sloppy, -beta * omega, v, beta, p); + vector beta_omega(b.size()); + for (auto i = 
0u; i < b.size(); i++) beta_omega[i] = -beta[i] * omega[i]; + blas::cxpaypbz(r_sloppy, beta_omega, v, beta, p); // tripleBiCGstabUpdate(alpha, p, omega, r_sloppy, x_sloppy, t, -beta*omega, v, beta, p } else { // x += alpha*p + omega*r, r -= omega*t, r2 = (r,r), rho = (r0, r) rho_r2 = blas::caxpbypzYmbwcDotProductUYNormY(alpha, p, omega, r_sloppy, x_sloppy, t, r0); rho0 = rho; - rho = Complex(rho_r2.x, rho_r2.y); - r2 = rho_r2.z; + for (auto i = 0u; i < b.size(); i++) { + rho[i] = Complex(rho_r2[i].x, rho_r2[i].y); + r2[i] = rho_r2[i].z; + } } if (use_heavy_quark_res && k % heavy_quark_check == 0) { - if (&x != &x_sloppy) { - heavy_quark_res = sqrt(blas::HeavyQuarkResidualNorm(x_sloppy, r_sloppy).z); + vector hq; + + if (x.Precision() != x_sloppy[0].Precision()) { + hq = blas::HeavyQuarkResidualNorm(x_sloppy, r_sloppy); } else { blas::copy(r, r_sloppy); - heavy_quark_res = sqrt(blas::xpyHeavyQuarkResidualNorm(x, y, r).z); + hq = blas::xpyHeavyQuarkResidualNorm(x, y, r); } + for (auto i = 0u; i < b.size(); i++) heavy_quark_res[i] = sqrt(hq[i].z); } - if (!param.pipeline) updateR = reliable(rNorm, maxrx, maxrr, r2, delta); + if (!param.pipeline) updateR = reliable(rNorm, maxrx, maxrr, r2[0], delta); if (updateR) { - if (x.Precision() != x_sloppy.Precision()) blas::copy(x, x_sloppy); + if (x.Precision() != x_sloppy[0].Precision()) blas::copy(x, x_sloppy); blas::xpy(x, y); mat(r, y); r2 = blas::xmyNorm(b, r); - if (param.deflate && sqrt(r2) < param.tol_restart) { + if (param.deflate && sqrt(r2[0]) < param.tol_restart) { // Deflate and accumulate to solution vector eig_solve->deflate(y, r, evecs, evals, true); @@ -290,10 +284,10 @@ namespace quda { r2 = blas::xmyNorm(b, r); } - if (x.Precision() != r_sloppy.Precision()) blas::copy(r_sloppy, r); + if (param.precision != param.precision_sloppy) blas::copy(r_sloppy, r); blas::zero(x_sloppy); - rNorm = sqrt(r2); + rNorm = sqrt(r2[0]); maxrr = rNorm; maxrx = rNorm; // r0Norm = rNorm; @@ -305,20 +299,17 @@ namespace quda { 
k++; PrintStats("BiCGstab", k, r2, b2, heavy_quark_res); - logQuda(QUDA_DEBUG_VERBOSE, "BiCGstab debug: x2=%e, r2=%e, v2=%e, p2=%e, r0=%e, t2=%e\n", blas::norm2(x), - blas::norm2(r_sloppy), blas::norm2(v), blas::norm2(p), blas::norm2(r0), blas::norm2(t)); - converged = convergence(r2, heavy_quark_res, stop, param.tol_hq); if (converged) { // make sure we've truly converged if (!just_updated) { - if (x.Precision() != x_sloppy.Precision()) blas::copy(x, x_sloppy); + if (x.Precision() != x_sloppy[0].Precision()) blas::copy(x, x_sloppy); blas::xpy(x, y); mat(r, y); r2 = blas::xmyNorm(b, r); - if (param.deflate && sqrt(r2) < param.tol_restart) { + if (param.deflate && sqrt(r2[0]) < param.tol_restart) { // Deflate and accumulate to solution vector eig_solve->deflate(y, r, evecs, evals, true); // Compute r_defl = RHS - A * LHS @@ -326,10 +317,10 @@ namespace quda { r2 = blas::xmyNorm(b, r); } - if (x.Precision() != r_sloppy.Precision()) blas::copy(r_sloppy, r); + if (r[0].Precision() != r_sloppy[0].Precision()) blas::copy(r_sloppy, r); blas::zero(x_sloppy); - rNorm = sqrt(r2); + rNorm = sqrt(r2[0]); maxrr = rNorm; maxrx = rNorm; // r0Norm = rNorm; @@ -339,7 +330,10 @@ namespace quda { } // explicitly compute the HQ residual if need be - heavy_quark_res = use_heavy_quark_res ? 
sqrt(blas::HeavyQuarkResidualNorm(y, r).z) : 0.0; + if (use_heavy_quark_res) { + auto hq = blas::HeavyQuarkResidualNorm(y, r); + for (auto i = 0u; i < b.size(); i++) heavy_quark_res[i] = sqrt(hq[i].z); + } // Update convergence check converged = convergence(r2, heavy_quark_res, stop, param.tol_hq); @@ -347,11 +341,15 @@ namespace quda { // update p if ((!param.pipeline || updateR) && !converged) { // need to update if not pipeline or did a reliable update - if (abs(rho * alpha) == 0.0) - beta = 0.0; - else - beta = (rho / rho0) * (alpha / omega); - blas::cxpaypbz(r_sloppy, -beta * omega, v, beta, p); + vector beta_omega(b.size()); + for (auto i = 0u; i < b.size(); i++) { + if (abs(rho[i] * alpha[i]) == 0.0) + beta[i] = 0.0; + else + beta[i] = (rho[i] / rho0[i]) * (alpha[i] / omega[i]); + beta_omega[i] = -beta[i] * omega[i]; + } + blas::cxpaypbz(r_sloppy, beta_omega, v, beta, p); } } @@ -371,9 +369,11 @@ namespace quda { if (!param.is_preconditioner) { // do not do the below if we this is an inner solver // r2 was freshly computed - param.true_res = sqrt(r2 / b2); - param.true_res_hq = use_heavy_quark_res ? - auto hq = use_heavy_quark_res ?
blas::HeavyQuarkResidualNorm(x, r) : vector(b.size(), {}); + for (auto i = 0u; i < b.size(); i++) { + param.true_res[i] = sqrt(r2[i] / b2[i]); + param.true_res_hq[i] = sqrt(hq[i].z); + } PrintSummary("BiCGstab", k, r2, b2, stop, param.tol_hq); } From cbf5943397df5f330fa6cf7e0f4d1a552d5b17c4 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Mon, 16 Sep 2024 12:25:13 -0700 Subject: [PATCH 073/103] BiCGStab(l) is now multi-RHS --- include/invert_quda.h | 76 ++++++----- lib/inv_bicgstabl_quda.cpp | 258 +++++++++++++++++++++---------------- 2 files changed, 187 insertions(+), 147 deletions(-) diff --git a/include/invert_quda.h b/include/invert_quda.h index 5ae9e5aa0c..1acb9c67ef 100644 --- a/include/invert_quda.h +++ b/include/invert_quda.h @@ -1021,7 +1021,6 @@ namespace quda { */ class BiCGstabL : public Solver { - private: const DiracMdagM matMdagM; // used by the eigensolver /** @@ -1031,41 +1030,46 @@ namespace quda { int pipeline; // pipelining factor for legacyGramSchmidt // Various coefficients and params needed on each iteration. - Complex rho0, rho1, alpha, omega, beta; // Various coefficients for the BiCG part of BiCGstab-L. - std::vector gamma, gamma_prime, gamma_prime_prime; // Parameters for MR part of BiCGstab-L. (L+1) length. - std::vector tau; // Parameters for MR part of BiCGstab-L. Tech. modified Gram-Schmidt coeffs. (L+1)x(L+1) length. - std::vector sigma; // Parameters for MR part of BiCGstab-L. Tech. the normalization part of Gram-Scmidt. (L+1) length. + vector rho0, rho1, alpha, omega, beta; // Various coefficients for the BiCG part of BiCGstab-L. + vector> gamma, gamma_prime, gamma_prime_prime; // Parameters for MR part of BiCGstab-L. (L+1) length. + vector> tau; // Parameters for MR part of BiCGstab-L. Tech. modified Gram-Schmidt coeffs. (L+1)x(L+1) length. + vector> + sigma; // Parameters for MR part of BiCGstab-L. Tech. the normalization part of Gram-Schmidt. (L+1) length. - ColorSpinorField r_full; //! Full precision residual.
- ColorSpinorField y; //! Full precision temporary. + std::vector r_full; //! Full precision residual. + std::vector y; //! Full precision temporary. // sloppy precision fields - ColorSpinorField temp; //! Sloppy temporary vector. - std::vector r; // Current residual + intermediate residual values, along the MR. - std::vector u; // Search directions. + std::vector temp; //! Sloppy temporary vector. + std::vector> r; // Current residual + intermediate residual values, along the MR. + std::vector> u; // Search directions. - ColorSpinorField x_sloppy; //! Sloppy solution vector. - ColorSpinorField r0; //! Shadow residual, in BiCG language. + std::vector x_sloppy; //! Sloppy solution vector. + std::vector r0; //! Shadow residual, in BiCG language. /** @brief Allocate persistent fields and parameter checking - @param[in] x Solution vector - @param[in] b Source vector + @param[in] x Solution vector set + @param[in] b Source vector set */ - void create(ColorSpinorField &x, const ColorSpinorField &b); + void create(cvector_ref &x, cvector_ref &b); /** - @brief Internal routine for reliable updates. Made to not conflict with BiCGstab's implementation. + @brief Internal routine for reliable updates. Made to not conflict with BiCGstab's implementation. 
*/ int reliable(double &rNorm, double &maxrx, double &maxrr, const double &r2, const double &delta); /** * @brief Internal routine for performing the MR part of BiCGstab-L * - * @param x_sloppy [out] sloppy accumulator for x - * @param fixed_iteration [in] whether or not this is for a fixed iteration solver + * @param[in,out] x_sloppy sloppy accumulator for x + * @param[in,out] u search directions + * @param[in,out] r residual vectors + * @param[in] fixed_iteration whether or not this is for a fixed iteration solver + * @param[in] src_idx which src we are presently working on */ - void computeMR(ColorSpinorField &x_sloppy, bool fixed_iteration); + void computeMR(ColorSpinorField &x_sloppy, cvector_ref &u, cvector_ref &r, + bool fixed_iteration, int src_idx); /** Legacy routines that encapsulate the original pipelined Gram-Schmit. @@ -1079,29 +1083,36 @@ namespace quda { * @brief Internal routine that comptues the "tau" matrix as described in * the original BiCGstab-L paper, supporting pipelining * - * @param begin [in] begin offset for pipelining - * @param size [in] length of pipelining - * @param j [in] row of tau being computed + * @param[in] begin begin offset for pipelining + * @param[in] size length of pipelining + * @param[in] j row of tau being computed + * @param[in] src_idx which src we are presently working on */ - void computeTau(int begin, int size, int j); + void computeTau(int begin, int size, int j, cvector_ref &r, int src_idx); /** * @brief Internal routine that updates R as described in * the original BiCGstab-L paper, supporting pipelining.
* - * @param begin [in] begin offset for pipelining - * @param size [in] length of pipelining - * @param j [in] row of tau being computed + * @param[in] begin begin offset for pipelining + * @param[in] size length of pipelining + * @param[in] j row of tau being computed + * @param[in,out] r Residual vector set + * @param[in] src_idx which src we are presently working on */ - void updateR(int begin, int size, int j); + void updateR(int begin, int size, int j, cvector_ref &r, int src_idx); /** * @brief Internal legacy routine for performing the MR part of BiCGstab-L * which more closely matches the paper * - * @param x_sloppy [out] sloppy accumulator for x + * @param[in,out] x_sloppy sloppy accumulator for x + * @param[in,out] u Direction vector set + * @param[in,out] r Residual vector set + * @param[in] src_idx which src we are presently working on */ - void legacyComputeMR(ColorSpinorField &x_sloppy); + void legacyComputeMR(ColorSpinorField &x_sloppy, cvector_ref &u, cvector_ref &r, + int src_idx); /** Solver uses lazy allocation: this flag determines whether we have allocated or not. 
@@ -1114,12 +1125,7 @@ namespace quda { BiCGstabL(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matEig, SolverParam &param); virtual ~BiCGstabL(); - void operator()(cvector_ref &out, cvector_ref &in) override - { - for (auto i = 0u; i < in.size(); i++) operator()(out[i], in[i]); - } - - void operator()(ColorSpinorField &out, const ColorSpinorField &in); + void operator()(cvector_ref &out, cvector_ref &in) override; virtual bool hermitian() const override { return false; } /** BiCGStab is for any linear system */ diff --git a/lib/inv_bicgstabl_quda.cpp b/lib/inv_bicgstabl_quda.cpp index 9d02e87ca6..42f5b50fbb 100644 --- a/lib/inv_bicgstabl_quda.cpp +++ b/lib/inv_bicgstabl_quda.cpp @@ -22,9 +22,17 @@ namespace quda { + auto get_i = [](std::vector> &p, int i) { + vector_ref p_i; + p_i.reserve(p.size()); + for (auto &pi : p) p_i.push_back(pi[i]); + return p_i; + }; + #ifndef LEGACY_MR // Compute the MR portion of BiCGstab-L - void BiCGstabL::computeMR(ColorSpinorField &x_sloppy, bool fixed_iteration) + void BiCGstabL::computeMR(ColorSpinorField &x_sloppy, cvector_ref &u, + cvector_ref &r, bool fixed_iteration, int src_idx) { using matrix = Matrix, Dynamic, Dynamic, RowMajor>; using vector = Matrix, Dynamic, 1>; @@ -64,12 +72,11 @@ namespace quda { } // Update omega for the next BiCG iteration - omega = gamma(n_krylov - 1); + omega[src_idx] = gamma(n_krylov - 1); std::vector gamma_(n_krylov); if (!fixed_iteration) { - // update u { // u = u[0] - \sum_{j=1}^L \gamma_L u_L @@ -100,39 +107,46 @@ namespace quda { } } - void BiCGstabL::computeTau(int, int, int) + void BiCGstabL::computeTau(int, int, int, cvector_ref &, int) { errorQuda("Legacy MR path in BiCGstab-L called with a non-legacy compile"); } - void BiCGstabL::updateR(int, int, int) { errorQuda("Legacy MR path in BiCGstab-L called with a non-legacy compile"); } + void BiCGstabL::updateR(int, int, int, cvector_ref &, int) + { + errorQuda("Legacy MR path in BiCGstab-L called with a non-legacy
compile"); + } - void BiCGstabL::legacyComputeMR(ColorSpinorField &) + void BiCGstabL::legacyComputeMR(ColorSpinorField &, cvector_ref &, cvector_ref &, + int) { errorQuda("Legacy MR path in BiCGstab-L called with a non-legacy compile"); } #else - void BiCGstabL::computeMR(ColorSpinorField &, bool) + void BiCGstabL::computeMR(ColorSpinorField &, cvector_ref &, cvector_ref &, bool, + int) { errorQuda("Non-legacy MR path in BiCGstab-L called with a legacy compile"); } // Utility functions for Gram-Schmidt. Based on GCR functions. // Big change is we need to go from 1 to n_krylov, not 0 to n_krylov-1. - void BiCGstabL::computeTau(int begin, int size, int j) + void BiCGstabL::computeTau(int begin, int size, int j, cvector_ref &r, int src_idx) { std::vector Tau(size); blas::block::cDotProduct(Tau, {r.begin() + begin, r.begin() + begin + size}, r[j]); // vectorized dot product - for (int k = 0; k < size; k++) { tau[(begin + k) * (n_krylov + 1) + j] = Tau[k] / sigma[begin + k]; } + for (int k = 0; k < size; k++) { + tau[src_idx][(begin + k) * (n_krylov + 1) + j] = Tau[k] / sigma[src_idx][begin + k]; + } } - void BiCGstabL::updateR(int begin, int size, int j) + void BiCGstabL::updateR(int begin, int size, int j, cvector_ref &r, int src_idx) { std::vector tau_(size); - for (int i = 0; i < size; i++) { tau_[i] = -tau[(i + begin) * (n_krylov + 1) + j]; } + for (int i = 0; i < size; i++) { tau_[i] = -tau[src_idx][(i + begin) * (n_krylov + 1) + j]; } auto r_ = {r.begin() + begin, r.begin() + begin + size}; auto rj = {r.begin() + j, r.begin() + j + 1}; @@ -144,7 +158,8 @@ namespace quda { Legacy routine for the original pipelined Gram-Schmit See "The MR part" in https://etna.math.kent.edu/vol.1.1993/pp11-32.dir/pp11-32.pdf */ - void BiCGstabL::legacyComputeMR(ColorSpinorField &x_sloppy) + void BiCGstabL::legacyComputeMR(ColorSpinorField &x_sloppy, cvector_ref &u, + cvector_ref &r, int src_idx) { // MR part. Really just modified Gram-Schmidt. 
// The algorithm uses the byproducts of the Gram-Schmidt to update x @@ -157,8 +172,8 @@ namespace quda { case 0: // no kernel fusion for (int i = 1; i < j; i++) // 5 (j-2) memory transactions here. Start at 1 b/c bicgstabl convention. { - tau[i * (n_krylov + 1) + j] = blas::cDotProduct(r[i], r[j]) / sigma[i]; - blas::caxpy(-tau[i * (n_krylov + 1) + j], r[i], r[j]); + tau[src_idx][i * (n_krylov + 1) + j] = blas::cDotProduct(r[i], r[j]) / sigma[src_idx][i]; + blas::caxpy(-tau[src_idx][i * (n_krylov + 1) + j], r[i], r[j]); } break; case 1: // basic kernel fusion @@ -166,13 +181,13 @@ namespace quda { { break; } - tau[1 * (n_krylov + 1) + j] = blas::cDotProduct(r[1], r[j]) / sigma[1]; + tau[src_idx][1 * (n_krylov + 1) + j] = blas::cDotProduct(r[1], r[j]) / sigma[src_idx][1]; for (int i = 1; i < j - 1; i++) // 4 (j-2) memory transactions here. start at 1. { - tau[(i + 1) * (n_krylov + 1) + j] - = blas::caxpyDotzy(-tau[i * (n_krylov + 1) + j], r[i], r[j], r[i + 1]) / sigma[i + 1]; + auto dot = blas::caxpyDotzy(-tau[src_idx][i * (n_krylov + 1) + j], r[i], r[j], r[i + 1]); + tau[src_idx][(i + 1) * (n_krylov + 1) + j] = dot / sigma[src_idx][i + 1]; } - blas::caxpy(-tau[(j - 1) * (n_krylov + 1) + j], r[j - 1], r[j]); + blas::caxpy(-tau[src_idx][(j - 1) * (n_krylov + 1) + j], r[j - 1], r[j]); break; default: { const int N = pipeline; @@ -183,15 +198,15 @@ namespace quda { // (j-1)/N updates of length N, at 1,1+N,1+2*N,... int step; for (step = 0; step < (j - 1) / N; step++) { - computeTau(1 + step * N, N, j); - updateR(1 + step * N, N, j); + computeTau(1 + step * N, N, j, r, src_idx); + updateR(1 + step * N, N, j, r, src_idx); } if ((j - 1) % N != 0) // need to update the remainder { // 1 update of length (j-1)%N. 
- computeTau(1 + step * N, (j - 1) % N, j); - updateR(1 + step * N, (j - 1) % N, j); + computeTau(1 + step * N, (j - 1) % N, j, r, src_idx); + updateR(1 + step * N, (j - 1) % N, j, r, src_idx); } } break; } @@ -199,27 +214,30 @@ namespace quda { // sigma_j = r_j^2, gamma'_j = /sigma_j // rjr.x = Re(), rjr.z = - double3 rjr = blas::cDotProductNormA(r[j], r[0]); - sigma[j] = rjr.z; - gamma_prime[j] = Complex(rjr.x, rjr.y) / sigma[j]; + auto rjr = blas::cDotProductNormA(r[j], r[0]); + sigma[src_idx][j] = rjr.z; + gamma_prime[src_idx][j] = Complex(rjr.x, rjr.y) / sigma[src_idx][j]; } // gamma[n_krylov] = gamma'[n_krylov], omega = gamma[n_krylov] - gamma[n_krylov] = gamma_prime[n_krylov]; - omega = gamma[n_krylov]; + gamma[src_idx][n_krylov] = gamma_prime[src_idx][n_krylov]; + omega[src_idx] = gamma[src_idx][n_krylov]; // gamma = T^(-1) gamma_prime. It's in the paper, I promise. for (int j = n_krylov - 1; j > 0; j--) { // Internal def: gamma[j] = gamma'_j - \sum_{i = j+1 to n_krylov} tau_ji gamma_i - gamma[j] = gamma_prime[j]; - for (int i = j + 1; i <= n_krylov; i++) { gamma[j] = gamma[j] - tau[j * (n_krylov + 1) + i] * gamma[i]; } + gamma[src_idx][j] = gamma_prime[src_idx][j]; + for (int i = j + 1; i <= n_krylov; i++) { + gamma[src_idx][j] = gamma[src_idx][j] - tau[src_idx][j * (n_krylov + 1) + i] * gamma[src_idx][i]; + } } // gamma'' = T S gamma. Check paper for defn of S. 
for (int j = 1; j < n_krylov; j++) { - gamma_prime_prime[j] = gamma[j + 1]; + gamma_prime_prime[src_idx][j] = gamma[src_idx][j + 1]; for (int i = j + 1; i < n_krylov; i++) { - gamma_prime_prime[j] = gamma_prime_prime[j] + tau[j * (n_krylov + 1) + i] * gamma[i + 1]; + gamma_prime_prime[src_idx][j] + = gamma_prime_prime[src_idx][j] + tau[src_idx][j * (n_krylov + 1) + i] * gamma[src_idx][i + 1]; } } @@ -229,7 +247,7 @@ namespace quda { // Update U { std::vector gamma_(n_krylov); - for (int i = 0; i < n_krylov; i++) { gamma_[i] = -gamma[i + 1]; } + for (int i = 0; i < n_krylov; i++) { gamma_[i] = -gamma[src_idx][i + 1]; } blas::block::caxpy(gamma_, {u.begin() + 1, u.end()}, u[0]); } @@ -241,15 +259,15 @@ namespace quda { // the full precision, this can be a killer. std::vector gamma_prime_prime_(n_krylov + 1); std::vector gamma_prime_(n_krylov + 1); - gamma_prime_prime_[0] = gamma[1]; + gamma_prime_prime_[0] = gamma[src_idx][1]; gamma_prime_prime_[n_krylov] = 0.0; // x never gets updated with r[n_krylov] gamma_prime_[0] = 0.0; // r[0] never gets updated with r[0]... obvs. 
- gamma_prime_[n_krylov] = -gamma_prime[n_krylov]; + gamma_prime_[n_krylov] = -gamma_prime[src_idx][n_krylov]; for (int i = 1; i < n_krylov; i++) { - gamma_prime_prime_[i] = gamma_prime_prime[i]; - gamma_prime_[i] = -gamma_prime[i]; + gamma_prime_prime_[i] = gamma_prime_prime[src_idx][i]; + gamma_prime_[i] = -gamma_prime[src_idx][i]; } - blas::caxpyBxpz(gamma_prime_prime_, r, x_sloppy, gamma_prime_, r[0]); + blas::block::caxpyBxpz(gamma_prime_prime_, r, x_sloppy, gamma_prime_, r[0]); } } @@ -273,12 +291,12 @@ namespace quda { class BiCGstabLUpdate : public Worker { - ColorSpinorField &x; - std::vector &r; - std::vector &u; + std::vector &x; + std::vector> &r; + std::vector> &u; - Complex &alpha; - Complex &beta; + std::vector &alpha; + std::vector &beta; BiCGstabLUpdateType update_type; @@ -296,13 +314,13 @@ namespace quda { int n_update; public: - BiCGstabLUpdate(ColorSpinorField &x, std::vector &r, std::vector &u, - Complex &alpha, Complex &beta, BiCGstabLUpdateType update_type, int j_max, int n_update) : + BiCGstabLUpdate(std::vector &x, std::vector> &r, + std::vector> &u, std::vector &alpha, + std::vector &beta, BiCGstabLUpdateType update_type, int j_max, int n_update) : x(x), r(r), u(u), alpha(alpha), beta(beta), update_type(update_type), j_max(j_max), n_update(n_update) { } - virtual ~BiCGstabLUpdate() { } void update_j_max(int new_j_max) { j_max = new_j_max; } void update_update_type(BiCGstabLUpdateType new_update_type) { update_type = new_update_type; } @@ -317,7 +335,7 @@ namespace quda { if (update_type == BICGSTABL_UPDATE_U) { for (int i = (count * j_max) / n_update; i < ((count + 1) * j_max) / n_update && i < j_max; i++) { - blas::caxpby(1.0, r[i], -beta, u[i]); + for (auto j = 0u; j < beta.size(); j++) blas::caxpby(1.0, r[i][j], -beta[j], u[i][j]); } } else // (update_type == BICGSTABL_UPDATE_R) @@ -325,9 +343,8 @@ namespace quda { if (count == 0) { blas::caxpy(alpha, u[0], x); } if (j_max > 0) { - for (int i= (count*j_max)/n_update; i<((count+1)*j_max)/n_update && i &x, 
cvector_ref &b) { Solver::create(x, b); - if (!init) { + if (!init || y.size() != b.size()) { getProfile().TPSTART(QUDA_PROFILE_INIT); // Initialize fields. - ColorSpinorParam csParam(x); + ColorSpinorParam csParam(x[0]); csParam.create = QUDA_ZERO_FIELD_CREATE; // Full precision variables. - r_full = ColorSpinorField(csParam); + resize(r_full, b.size(), csParam); // Create temporary. - y = ColorSpinorField(csParam); + resize(y, b.size(), csParam); // Sloppy precision variables. csParam.setPrecision(param.precision_sloppy); // Sloppy solution. if (!mixed() || !param.use_sloppy_partial_accumulator) { - x_sloppy = x.create_alias(); // x_sloppy and x point to the same vector in memory. + create_alias(x_sloppy, x); // x_sloppy and x point to the same vector in memory. } else { - x_sloppy = ColorSpinorField(csParam); + resize(x_sloppy, b.size(), csParam); } // Shadow residual. if (!mixed() && param.compute_null_vector == QUDA_COMPUTE_NULL_VECTOR_NO) { - r0 = const_cast(b).create_alias(); + create_alias(r0, b); } else { - r0 = ColorSpinorField(csParam); + resize(r0, b.size(), csParam); } // Temporary - temp = ColorSpinorField(csParam); + resize(temp, b.size(), csParam); // Residual (+ extra residuals for BiCG steps), Search directions. // Remark: search directions are sloppy in GCR. I wonder if we can // get away with that here. + r.resize(n_krylov + 1); + u.resize(n_krylov + 1); for (int i = 0; i <= n_krylov; i++) { - r[i] = (i > 0 || mixed()) ? 
ColorSpinorField(csParam) : r_full.create_alias(); - u[i] = ColorSpinorField(csParam); + if (i > 0 || mixed()) + resize(r[i], b.size(), csParam); + else + create_alias(r[i], r_full); + resize(u[i], b.size(), csParam); + } + + alpha.resize(b.size(), 0.0); + beta.resize(b.size()); + omega.resize(b.size(), 1.0); + rho0.resize(b.size(), 1.0); + rho1.resize(b.size()); + + gamma.resize(b.size()); + gamma_prime.resize(b.size()); + gamma_prime_prime.resize(b.size()); + sigma.resize(b.size()); + tau.resize(b.size()); + for (auto i = 0u; i < b.size(); i++) { + gamma[i].resize(n_krylov + 1); + gamma_prime[i].resize(n_krylov + 1); + gamma_prime_prime[i].resize(n_krylov + 1); + sigma[i].resize(n_krylov + 1); + tau[i].resize((n_krylov + 1) * (n_krylov + 1)); } getProfile().TPSTOP(QUDA_PROFILE_INIT); @@ -434,7 +465,7 @@ namespace quda { } } - void BiCGstabL::operator()(ColorSpinorField &x, const ColorSpinorField &b) + void BiCGstabL::operator()(cvector_ref &x, cvector_ref &b) { // BiCGstab-l is based on the algorithm outlined in // BICGSTAB(L) FOR LINEAR EQUATIONS INVOLVING UNSYMMETRIC MATRICES WITH COMPLEX SPECTRUM @@ -449,16 +480,16 @@ namespace quda { // compute b2, but only if we need to bool fixed_iteration = param.sloppy_converge && n_krylov == param.maxiter && !param.compute_true_res; - double b2 = !fixed_iteration ? blas::norm2(b) : 1.0; // norm sq of source. - double r2; // norm sq of residual + auto b2 = !fixed_iteration ? blas::norm2(b) : vector(b.size(), 1.0); // norm sq of source. + vector r2(b.size()); // norm sq of residual if (param.deflate) { // Construct the eigensolver and deflation space if requested. 
if (param.eig_param.eig_type == QUDA_EIG_TR_LANCZOS || param.eig_param.eig_type == QUDA_EIG_BLK_TR_LANCZOS) { - constructDeflationSpace(b, matMdagM); + constructDeflationSpace(b[0], matMdagM); } else { // Use Arnoldi to inspect the space only and turn off deflation - constructDeflationSpace(b, mat); + constructDeflationSpace(b[0], mat); param.deflate = false; } if (deflate_compute) { @@ -514,19 +545,9 @@ namespace quda { } // Check to see that we're not trying to invert on a zero-field source - if (b2 == 0) { - if (param.compute_null_vector == QUDA_COMPUTE_NULL_VECTOR_NO) { - warningQuda("inverting on zero-field source"); - x = b; - param.true_res = 0.0; - param.true_res_hq = 0.0; - getProfile().TPSTOP(QUDA_PROFILE_PREAMBLE); - return; - } else if (param.use_init_guess == QUDA_USE_INIT_GUESS_YES) { - b2 = r2; - } else { - errorQuda("Null vector computing requires non-zero guess!"); - } + if (is_zero_src(x, b, b2)) { + getProfile().TPSTOP(QUDA_PROFILE_PREAMBLE); + return; } // Set shadow residual depending if the source vector is directly usable @@ -547,15 +568,16 @@ namespace quda { // Initialize values. for (int i = 1; i <= n_krylov; i++) { blas::zero(r[i]); } - rho0 = 1.0; - alpha = 0.0; - omega = 1.0; - - double stop = !fixed_iteration ? stopping(param.tol, b2, param.residual_type) : 0.0; // stopping condition of solver. + auto stop = !fixed_iteration ? stopping(param.tol, b2, param.residual_type) : + vector(b.size(), 0.0); // stopping condition of solver. const bool use_heavy_quark_res = (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) ? true : false; - double heavy_quark_res = use_heavy_quark_res ? 
sqrt(blas::HeavyQuarkResidualNorm(x, r_full).z) : 0.0; + vector heavy_quark_res(b.size(), 0.0); + if (use_heavy_quark_res) { + auto hq = blas::HeavyQuarkResidualNorm(x, r_full); + for (auto i = 0u; i < b.size(); i++) heavy_quark_res[i] = sqrt(hq[i].z); + } const int heavy_quark_check = param.heavy_quark_check; // how often to check the heavy quark residual //bool l2_converge = false; @@ -576,7 +598,7 @@ namespace quda { // Various variables related to reliable updates. int rUpdate = 0; // count reliable updates. double delta = param.delta; // delta for reliable updates. - double rNorm = sqrt(r2); // The current residual norm. + double rNorm = sqrt(r2[0]); // The current residual norm. double maxrr = rNorm; // The maximum residual norm since the last reliable update. double maxrx = rNorm; // The same. Would be different if we did 'x' reliable updates. @@ -584,19 +606,23 @@ namespace quda { while (!convergence(r2, 0.0, stop, 0.0) && total_iter < param.maxiter) { // rho0 = -omega*rho0; - rho0 *= -omega; + for (auto i = 0u; i < b.size(); i++) rho0[i] *= -omega[i]; // BiCG part of calculation. for (int j = 0; j < n_krylov; j++) { // rho1 = , beta = alpha*rho1/rho0, rho0 = rho1; // Can fuse into updateXRend. rho1 = blas::cDotProduct(r0, r[j]); - beta = alpha*rho1/rho0; - rho0 = rho1; + for (auto i = 0u; i < b.size(); i++) { + beta[i] = alpha[i] * rho1[i] / rho0[i]; + rho0[i] = rho1[i]; + } // for i = 0 .. j, u[i] = r[i] - beta*u[i] // All but i = j is hidden in Dslash auxillary work (overlapping comms and compute). - blas::caxpby(1.0, r[j], -beta, u[j]); + std::vector minus_beta(beta.size()); + for (auto i = 0u; i < beta.size(); i++) minus_beta[i] = -beta[i]; + blas::caxpby(1.0, r[j], minus_beta, u[j]); if (j > 0) { dslash::aux_worker = &bicgstabl_update; @@ -611,11 +637,14 @@ namespace quda { // alpha = rho0/ // The machinary isn't there yet, but this could be fused with the matSloppy above. 
- alpha = rho0/blas::cDotProduct(r0, u[j+1]); + auto r0Tu = blas::cDotProduct(r0, u[j + 1]); + for (auto i = 0u; i < b.size(); i++) alpha[i] = rho0[i] / r0Tu[i]; // for i = 0 .. j, r[i] = r[i] - alpha u[i+1] // All but i = j is hidden in Dslash auxillary work (overlapping comms and compute). - blas::caxpy(-alpha, u[j+1], r[j]); + std::vector minus_alpha(alpha.size()); + for (auto i = 0u; i < alpha.size(); i++) minus_alpha[i] = -alpha[i]; + blas::caxpy(minus_alpha, u[j + 1], r[j]); // We can always at least update x. dslash::aux_worker = &bicgstabl_update; bicgstabl_update.update_j_max(j); @@ -630,11 +659,11 @@ namespace quda { #ifndef LEGACY_MR // Perform the MR portion of BiCGstab-L // if we're doing a fixed number of iterations, we only need to update x - computeMR(x_sloppy, fixed_iteration); + for (auto i = 0u; i < b.size(); i++) computeMR(x_sloppy[i], get_i(u, i), get_i(r, i), fixed_iteration, i); #else // Legacy version matching the BiCGstab-L paper which performs // an explicit Gram-Schmidt for the MR portion - legacyComputeMR(x_sloppy); + for (auto i = 0u; i < b.size(); i++) legacyComputeMR(x_sloppy[i], get_i(u, i), get_i(r, i), i); #endif if (!fixed_iteration) { @@ -644,12 +673,14 @@ namespace quda { // Check the heavy quark residual if we need to. 
if (use_heavy_quark_res && total_iter % heavy_quark_check == 0) { - if (&x != &x_sloppy) { + if (x.Precision() != x_sloppy[0].Precision()) { blas::copy(temp, y); - heavy_quark_res = sqrt(blas::xpyHeavyQuarkResidualNorm(x_sloppy, temp, r[0]).z); + auto hq = blas::xpyHeavyQuarkResidualNorm(x_sloppy, temp, r[0]); + for (auto i = 0u; i < b.size(); i++) heavy_quark_res[i] = sqrt(hq[i].z); } else { blas::copy(r_full, r[0]); - heavy_quark_res = sqrt(blas::xpyHeavyQuarkResidualNorm(x, y, r_full).z); + auto hq = blas::xpyHeavyQuarkResidualNorm(x, y, r_full); + for (auto i = 0u; i < b.size(); i++) heavy_quark_res[i] = sqrt(hq[i].z); } } } @@ -660,7 +691,7 @@ namespace quda { // updated (depending on if you're using pipelining or not). In BiCGstab-L, there's only // one place (for now) to get the updated residual, so we just do away with 'updateR'. // Further remark: "reliable" updates rNorm, maxrr, maxrx!! - if (total_iter >= param.maxiter || r2 < stop || reliable(rNorm, maxrx, maxrr, r2, delta)) { + if (total_iter >= param.maxiter || r2 < stop || reliable(rNorm, maxrx, maxrr, r2[0], delta)) { if ((r2 < stop || total_iter >= param.maxiter) && param.sloppy_converge) break; if (mixed()) { blas::copy(x, x_sloppy); } @@ -681,7 +712,7 @@ namespace quda { blas::zero(x_sloppy); // Update rNorm, maxrr, maxrx. - rNorm = sqrt(r2); + rNorm = sqrt(r2[0]); maxrr = rNorm; maxrx = rNorm; @@ -713,9 +744,12 @@ namespace quda { // !param.is_preconditioner comes from bicgstab, param.compute_true_res came from gcr. if (!param.is_preconditioner && param.compute_true_res) { // do not do the below if this is an inner solver. mat(r_full, x); - double true_res = blas::xmyNorm(b, r_full); - param.true_res = sqrt(true_res / b2); - param.true_res_hq = use_heavy_quark_res ? sqrt(blas::HeavyQuarkResidualNorm(x, r[0]).z) : 0.0; + auto true_res = blas::xmyNorm(b, r_full); + auto hq = use_heavy_quark_res ? 
blas::HeavyQuarkResidualNorm(x, r[0]) : vector(b.size(), {}); + for (auto i = 0u; i < b.size(); i++) { + param.true_res[i] = sqrt(true_res[i] / b2[i]); + param.true_res_hq[i] = sqrt(hq[i].z); + } } getProfile().TPSTOP(QUDA_PROFILE_EPILOGUE); From dce23c6a958565b628b5bf82f084ba98b91ac679 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Mon, 16 Sep 2024 19:02:56 -0700 Subject: [PATCH 074/103] Fix typo. Closes #1492 --- tests/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index a26599564f..823633a408 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -1032,7 +1032,7 @@ if(QUDA_DIRAC_TWISTED_MASS) --matpc even-even --enable-testing true --gtest_output=xml:invert_test_twisted_mass_sym.xml) - add_test(NAME invert_test_twisted_mass_asym} + add_test(NAME invert_test_twisted_mass_asym COMMAND ${QUDA_CTEST_LAUNCH} $ ${MPIEXEC_POSTFLAGS} --dslash-type twisted-mass --dim 2 4 6 8 --niter 1000 --ngcrkrylov 8 From 827700d528a70edc8127bf669f782df109effcc2 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Wed, 18 Sep 2024 08:53:52 -0700 Subject: [PATCH 075/103] Use fine-grain parallelization for CopySpinor --- include/kernels/copy_color_spinor_mg.cuh | 12 +++--------- lib/copy_color_spinor_mg.in.hpp | 12 ++++-------- 2 files changed, 7 insertions(+), 17 deletions(-) diff --git a/include/kernels/copy_color_spinor_mg.cuh b/include/kernels/copy_color_spinor_mg.cuh index 32b6d3de87..fefc6a1277 100644 --- a/include/kernels/copy_color_spinor_mg.cuh +++ b/include/kernels/copy_color_spinor_mg.cuh @@ -16,9 +16,7 @@ namespace quda { template CopyArg(ColorSpinorField &out, const ColorSpinorField &in, T1 *Out, T2 *In) : - kernel_param(in.VolumeCB()), - out(out, 1, Out), - in(in, 1, In) + kernel_param(dim3(in.VolumeCB(), nSpin, nColor)), out(out, 1, Out), in(in, 1, In) {} }; @@ -27,13 +25,9 @@ namespace quda { constexpr CopySpinor_(const Arg &arg) : arg(arg) {} static constexpr const char *filename() 
{ return KERNEL_FILE; } - __device__ __host__ inline void operator()(int x_cb) + __device__ __host__ inline void operator()(int x_cb, int s, int c) { - for (int s=0; s - class CopySpinor : TunableKernel1D { + class CopySpinor : TunableKernel3D + { ColorSpinorField &out; const ColorSpinorField &in; FloatOut *Out; @@ -27,12 +28,8 @@ namespace quda { unsigned int minThreads() const { return in.VolumeCB(); } public: - CopySpinor(ColorSpinorField &out, const ColorSpinorField &in, QudaFieldLocation location, FloatOut* Out, FloatIn* In) : - TunableKernel1D(in, location), - out(out), - in(in), - Out(Out), - In(In) + CopySpinor(ColorSpinorField &out, const ColorSpinorField &in, QudaFieldLocation location, FloatOut *Out, FloatIn *In) : + TunableKernel3D(in, in.Nspin(), in.Ncolor(), location), out(out), in(in), Out(Out), In(In) { apply(device::get_default_stream()); } @@ -44,7 +41,6 @@ namespace quda { launch(tp, stream, CopyArg(out, in, Out, In)); } - long long flops() const { return 0; } long long bytes() const { return in.Bytes() + out.Bytes(); } }; From ade0e1633e07332230fc4e45919be7917ef6f9b4 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Wed, 18 Sep 2024 09:17:04 -0700 Subject: [PATCH 076/103] Ensure that mg_eig_evals_batch_size in test code has sensible default --- tests/utils/set_params.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/utils/set_params.cpp b/tests/utils/set_params.cpp index bc124c6c31..445a2a0b28 100644 --- a/tests/utils/set_params.cpp +++ b/tests/utils/set_params.cpp @@ -795,6 +795,8 @@ void setMultigridEigParam(QudaEigParam &mg_eig_param, int level) mg_eig_param.n_ev = mg_eig_n_ev[level]; mg_eig_param.n_kr = mg_eig_n_kr[level]; mg_eig_param.n_conv = nvec[level]; + mg_eig_param.compute_evals_batch_size + = mg_eig_evals_batch_size[level] ? mg_eig_evals_batch_size[level] : eig_evals_batch_size; // Inverters will deflate only this number of vectors. 
if (mg_eig_n_ev_deflate[level] > 0 && mg_eig_n_ev_deflate[level] < mg_eig_param.n_conv) From 36ba139856a8a7e25a9734a509a5084457aecb6c Mon Sep 17 00:00:00 2001 From: Evan Weinberg Date: Mon, 23 Sep 2024 12:37:15 -0700 Subject: [PATCH 077/103] Updated MILC interface to batched CG, hq tolerance bugfix in CG itself --- lib/inv_cg_quda.cpp | 15 ++++++++------- lib/milc_interface.cpp | 16 ++++++++++++---- 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/lib/inv_cg_quda.cpp b/lib/inv_cg_quda.cpp index 5531639df1..0dd8dbb117 100644 --- a/lib/inv_cg_quda.cpp +++ b/lib/inv_cg_quda.cpp @@ -501,6 +501,7 @@ namespace quda { getProfile().TPSTART(QUDA_PROFILE_PREAMBLE); auto stop = stopping(param.tol, b2, param.residual_type); // stopping condition of solver + auto stop_hq = std::vector(b.size(), param.tol_hq); auto get_hq_res = [](cvector_ref &x, cvector_ref &r) { auto hq_nrm = blas::HeavyQuarkResidualNorm(x, r); @@ -529,7 +530,7 @@ namespace quda { PrintStats("CG", k, r2, b2, hq_res); - bool converged = convergence(r2, hq_res, stop, param.tol_hq); + bool converged = convergence(r2, hq_res, stop, stop_hq); // Various parameters related to restarts @@ -601,7 +602,7 @@ namespace quda { // we're still checking the L2 norm, or if that has converged/broken down and we're // now looking at the HQ residual. 
- if (!L2breakdown && (L2_required || convergenceL2(hq_res, param.tol_hq))) { + if (!L2breakdown && (L2_required || convergenceL2(r2, stop))) { // L2 based reliable update // If the iterated residual norm has gone above the most recent "baseline" norm, @@ -633,10 +634,10 @@ namespace quda { } // force a reliable update if we are within target tolerance (only if doing reliable updates) - if (convergence(r2, hq_res, stop, param.tol_hq) && param.delta >= param.tol) updateX = true; + if (convergence(r2, hq_res, stop, stop_hq) && param.delta >= param.tol) updateX = true; // force a reliable update based on the HQ residual if L2 breakdown has already happened - if (L2breakdown && (convergenceHQ(hq_res, param.tol_hq) || (r2[0] / b2[0]) < hq_res_stall_check) + if (L2breakdown && (convergenceHQ(hq_res, stop_hq) || (r2[0] / b2[0]) < hq_res_stall_check) && param.delta >= param.tol) updateX = true; @@ -812,14 +813,14 @@ namespace quda { PrintStats("CG", k, r2, b2, hq_res); // check convergence, if convergence is satisfied we only need to check that we had a reliable update for the heavy quarks recently - converged = convergence(r2, hq_res, stop, param.tol_hq); + converged = convergence(r2, hq_res, stop, stop_hq); // check for recent enough reliable updates of the HQ residual if we use it // L2 is converged or precision maxed out for L2 bool L2done = L2breakdown || convergenceL2(r2, stop); // HQ is converged and if we do reliable update the HQ residual has been calculated using a reliable update - bool HQdone = (steps_since_reliable == 0 && param.delta > 0) && convergenceHQ(hq_res, param.tol_hq); + bool HQdone = (steps_since_reliable == 0 && param.delta > 0) && convergenceHQ(hq_res, stop_hq); converged = L2done && HQdone; } @@ -846,7 +847,7 @@ namespace quda { } } - PrintSummary("CG", k, r2, b2, stop, param.tol_hq); + PrintSummary("CG", k, r2, b2, stop, stop_hq); getProfile().TPSTOP(QUDA_PROFILE_EPILOGUE); } diff --git a/lib/milc_interface.cpp b/lib/milc_interface.cpp index 
99d0c1b89c..24443b0a1f 100644 --- a/lib/milc_interface.cpp +++ b/lib/milc_interface.cpp @@ -1424,11 +1424,19 @@ void qudaInvertMsrc(int external_precision, int quda_precision, double mass, Qud host_free(sln_pointer); host_free(src_pointer); - // return the number of iterations taken by the inverter + // The conventions for num_iters, final_residual, and final_fermilab_residual are taken from the + // convention in `generic_ks/d_congrad5_fn_milc.c` (commit 414fb31). Here, a block solve + // is emulated as a series of sequential solves. Each individual solve overrides the + // final tolerance and iteration counts from the previous solve. Therefore, num_iters + // as well as the tolerances come from the last solve. + + // invertParam.iter is the total number of iterations for the block solver, which is ~= + // to the number of iterations the last rhs would take. *num_iters = invertParam.iter; - // FIXME MILC seems to only care about a single residual? - *final_residual = invertParam.true_res[0]; - *final_fermilab_residual = invertParam.true_res_hq[0]; + + // MILC only cares about a single residual, which happens to be the last one as described above. + *final_residual = invertParam.true_res[num_src - 1]; + *final_fermilab_residual = invertParam.true_res_hq[num_src - 1]; if (!create_quda_gauge) invalidateGaugeQuda(); From 50e78799c7b9e5babb2ced06bcf2f27e35f036e3 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Thu, 26 Sep 2024 13:51:23 -0700 Subject: [PATCH 078/103] PCG is now multi-RHS ready. 
Improve robustness of Solver convergence and print helpers to catch mis-sized or max-matched vector lengths --- include/invert_quda.h | 61 ++++----- include/reliable_updates.h | 5 +- lib/inv_cg_quda.cpp | 36 +++--- lib/inv_mr_quda.cpp | 2 +- lib/inv_pcg_quda.cpp | 248 +++++++++++++++++++++---------------- lib/madwf_ml.cpp | 4 +- lib/solver.cpp | 31 ++++- 7 files changed, 211 insertions(+), 176 deletions(-) diff --git a/include/invert_quda.h b/include/invert_quda.h index 1acb9c67ef..891a5fa947 100644 --- a/include/invert_quda.h +++ b/include/invert_quda.h @@ -14,21 +14,6 @@ namespace quda { -// temporary addition until multi-RHS for all Dirac operator functions -#ifdef __CUDACC__ -#ifdef __NVCC_DIAG_PRAGMA_SUPPORT__ -#pragma nv_diag_suppress 611 -#pragma nv_diag_suppress 997 -#else -#pragma diag_suppress 611 -#pragma diag_suppress 997 -#endif -#endif - -#ifdef __NVCOMPILER -#pragma diag_suppress partial_override -#endif - /** SolverParam is the meta data used to define linear solvers. */ @@ -473,10 +458,10 @@ namespace quda { @brief a virtual method that performs the inversion and collect some vectors. The default here is a no-op and should not be called. 
*/ - virtual void solve_and_collect(ColorSpinorField &, const ColorSpinorField &, cvector_ref &, int, - double) + virtual void solve_and_collect(cvector_ref &, cvector_ref &, + cvector_ref &, int, double) { - errorQuda("NOT implemented."); + errorQuda("Not implemented."); } void set_tol(double tol) { param.tol = tol; } @@ -909,21 +894,19 @@ namespace quda { virtual QudaInverterType getInverterType() const override { return QUDA_CG3_INVERTER; } }; - class PreconCG : public Solver { - private: + class PCG : public Solver + { std::shared_ptr K; SolverParam Kparam; // parameters for preconditioner solve - ColorSpinorField r; - ColorSpinorField y; - ColorSpinorField Ap; - ColorSpinorField x_sloppy; - ColorSpinorField r_sloppy; - ColorSpinorField minvr; - ColorSpinorField minvr_sloppy; - ColorSpinorField minvr_pre; - ColorSpinorField r_pre; - XUpdateBatch x_update_batch; + std::vector r; + std::vector y; + std::vector Ap; + std::vector x_sloppy; + std::vector r_sloppy; + std::vector minvr_sloppy; + std::vector minvr_pre; + std::vector r_pre; int Np; /** the size of the accumulator pipeline */ bool init = false; @@ -933,11 +916,11 @@ namespace quda { @param[in] x Solution vector @param[in] b Source vector */ - void create(ColorSpinorField &x, const ColorSpinorField &b); + void create(cvector_ref &x, cvector_ref &b); public: - PreconCG(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, - const DiracMatrix &matEig, SolverParam &param); + PCG(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, const DiracMatrix &matEig, + SolverParam &param); /** * @brief Preconditioned CG supporting a pre-existing preconditioner K. 
@@ -948,15 +931,14 @@ namespace quda { * @param matEig Deflation precision Dirac matrix * @param param Solver parameters */ - PreconCG(const DiracMatrix &mat, Solver &K, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, - const DiracMatrix &matEig, SolverParam &param); + PCG(const DiracMatrix &mat, Solver &K, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, + const DiracMatrix &matEig, SolverParam &param); - virtual ~PreconCG(); + virtual ~PCG(); void operator()(cvector_ref &out, cvector_ref &in) override { - for (auto i = 0u; i < in.size(); i++) - this->solve_and_collect(out[i], in[i], cvector_ref(), 0, 0); + solve_and_collect(out, in, {}, 0, 0); } /** @@ -967,7 +949,7 @@ namespace quda { @param collect_miniter minimal iteration start from which the r vectors are to be collected @param collect_tol maxiter tolerance start from which the r vectors are to be collected */ - virtual void solve_and_collect(ColorSpinorField &out, const ColorSpinorField &in, + virtual void solve_and_collect(cvector_ref &out, cvector_ref &in, cvector_ref &v_r, int collect_miniter, double collect_tol) override; virtual bool hermitian() const override { return true; } /** PCG is only Hermitian system */ @@ -975,7 +957,6 @@ namespace quda { virtual QudaInverterType getInverterType() const final { return QUDA_PCG_INVERTER; } }; - class BiCGstab : public Solver { private: diff --git a/include/reliable_updates.h b/include/reliable_updates.h index a1ddd9f122..ec70fb4f3c 100644 --- a/include/reliable_updates.h +++ b/include/reliable_updates.h @@ -163,8 +163,9 @@ namespace quda pnorm = pnorm + alpha * alpha * ppnorm; xnorm = sqrt(pnorm); d_new = d + params.u * rNorm + params.uhigh * params.Anorm * xnorm; - if (steps_since_reliable == 0 && getVerbosity() >= QUDA_DEBUG_VERBOSE) - printfQuda("New dnew: %e (r %e , y %e)\n", d_new, params.u * rNorm, params.uhigh * params.Anorm * xnorm); + if (steps_since_reliable == 0) + logQuda(QUDA_DEBUG_VERBOSE, "New dnew: %e (r %e , y %e)\n", d_new,
params.u * rNorm, + params.uhigh * params.Anorm * xnorm); } steps_since_reliable++; } diff --git a/lib/inv_cg_quda.cpp b/lib/inv_cg_quda.cpp index 0dd8dbb117..43df31e1b2 100644 --- a/lib/inv_cg_quda.cpp +++ b/lib/inv_cg_quda.cpp @@ -71,7 +71,7 @@ namespace quda { } const int Np = (param.solution_accumulator_pipeline == 0 ? 1 : param.solution_accumulator_pipeline); - if (Np < 0 || Np > 16) errorQuda("Invalid value %d for solution_accumulator_pipeline\n", Np); + if (Np < 0 || Np > 16) errorQuda("Invalid value %d for solution_accumulator_pipeline", Np); // Determine whether or not we're doing a heavy quark residual const bool use_heavy_quark_res = (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) ? true : false; @@ -86,8 +86,6 @@ namespace quda { // just in case HQ residual solves are split into a separate file if (use_heavy_quark_res) errorQuda("The \"vanilla\" CG solver does not support HQ residual solves"); - // whether to select alternative reliable updates - bool alternative_reliable = param.use_alternative_reliable; /** When CG is used as a preconditioner, and we disable the `advanced features`, these features are turned off: - Reliable updates @@ -99,7 +97,10 @@ namespace quda { if (!param.is_preconditioner) getProfile().TPSTART(QUDA_PROFILE_INIT); - vector b2 = blas::norm2(b); + // whether to select alternative reliable updates + bool alternative_reliable = param.use_alternative_reliable; + + auto b2 = blas::norm2(b); // Check to see that we're not trying to invert on a zero-field source if (is_zero_src(x, b, b2)) { @@ -165,8 +166,7 @@ namespace quda { if (param.use_sloppy_partial_accumulator) blas::zero(x_sloppy); blas::copy(r_sloppy, r); - ColorSpinorParam csParam(r_sloppy[0]); - csParam.create = QUDA_NULL_FIELD_CREATE; + auto csParam(r_sloppy[0]); std::vector x_update_batch(b.size()); for (auto i = 0u; i < b.size(); i++) x_update_batch[i] = XUpdateBatch(Np, !p_init[i].empty() ? 
p_init[i] : r_sloppy[i], csParam); @@ -199,7 +199,7 @@ namespace quda { int k = 0; - PrintStats("CG", k, r2, b2, 0.0); + PrintStats("CG", k, r2, b2); bool converged = convergenceL2(r2, stop); @@ -237,7 +237,6 @@ namespace quda { matSloppy(Ap, p); vector sigma(b.size()); - ; bool breakdown = false; if (advanced_feature && param.pipeline) { @@ -312,13 +311,15 @@ namespace quda { errorQuda("Not implemented pipelined CG with Np > 1"); } } else { + if (Np == 1) { // with Np=1 we just run regular fusion between x and p updates blas::axpyZpbx(get_alpha(x_update_batch), p, x_sloppy, r_sloppy, beta); } else { - for (auto i = 0u; i < b.size(); i++) - if (x_update_batch[i].is_container_full()) { x_update_batch[i].accumulate_x(x_sloppy[i]); } + for (auto i = 0u; i < b.size(); i++) { + if (x_update_batch[i].is_container_full()) x_update_batch[i].accumulate_x(x_sloppy[i]); + } // p[(k+1)%Np] = r + beta * p[k%Np] blas::xpayz(r_sloppy, beta, p, p_next); @@ -333,9 +334,8 @@ namespace quda { x_update_batch[i].accumulate_x(x_sloppy[i]); x_update_batch[i].reset_next(); } - blas::copy(x, x_sloppy); // nop when these pointers alias + blas::xpy(x_sloppy, y); // swap these around? - blas::xpy(x, y); // swap these around? 
mat(r, y); // here we can use x as tmp r2 = blas::xmyNorm(b, r); @@ -378,7 +378,7 @@ namespace quda { breakdown = false; k++; - PrintStats("CG", k, r2, b2, 0.0); + PrintStats("CG", k, r2, b2); // check convergence converged = convergenceL2(r2, stop); @@ -421,7 +421,7 @@ namespace quda { } } - PrintSummary("CG", k, r2, b2, stop, 0.0); + PrintSummary("CG", k, r2, b2, stop); if (!param.is_preconditioner) getProfile().TPSTOP(QUDA_PROFILE_EPILOGUE); @@ -1001,7 +1001,7 @@ namespace quda { int k = 0; - PrintStats("CG", k, r2avg / param.num_src, b2avg, 0.); + PrintStats("CG", k, r2avg / param.num_src, b2avg); bool allconverged = true; bool converged[QUDA_MAX_MULTI_SHIFT]; for (int i = 0; i < param.num_src; i++) { @@ -1131,7 +1131,7 @@ namespace quda { } k++; - PrintStats("CG", k, r2avg / param.num_src, b2avg, 0); + PrintStats("CG", k, r2avg / param.num_src, b2avg); // check convergence allconverged = true; for (int i = 0; i < param.num_src; i++) { @@ -1159,7 +1159,7 @@ namespace quda { param.true_res_offset[i] = param.true_res; param.true_res_hq_offset[i] = param.true_res_hq; - PrintSummary("CG", k, r2(i, i).real(), b2[i], stop[i], 0.0); + PrintSummary("CG", k, r2(i, i).real(), b2[i], stop[i]); } getProfile().TPSTOP(QUDA_PROFILE_EPILOGUE); @@ -1717,7 +1717,7 @@ void CG::solve(ColorSpinorField& x, ColorSpinorField& b) { param.true_res_offset[i] = param.true_res; param.true_res_hq_offset[i] = param.true_res_hq; - PrintSummary("CG", k, r2(i,i).real(), b2[i], stop[i], 0.0); + PrintSummary("CG", k, r2(i, i).real(), b2[i], stop[i]); } getProfile().TPSTOP(QUDA_PROFILE_EPILOGUE); diff --git a/lib/inv_mr_quda.cpp b/lib/inv_mr_quda.cpp index 032950e655..566ff5661e 100644 --- a/lib/inv_mr_quda.cpp +++ b/lib/inv_mr_quda.cpp @@ -149,7 +149,7 @@ namespace quda for (auto i = 0u; i < b2.size(); i++) param.true_res[i] = sqrt(r2[i] / b2[i]); converged = (step < param.Nsteps && r2 > stop) ? 
false : true; if (!converged) blas::copy(r_sloppy, r); - PrintStats("MR (restart)", iter, r2, b2, 0.0); + PrintStats("MR (restart)", iter, r2, b2); } else { blas::ax(scale, r_sloppy); r2 = blas::norm2(r_sloppy); diff --git a/lib/inv_pcg_quda.cpp b/lib/inv_pcg_quda.cpp index 4b15b58ac2..57244a04fb 100644 --- a/lib/inv_pcg_quda.cpp +++ b/lib/inv_pcg_quda.cpp @@ -18,8 +18,8 @@ namespace quda using namespace blas; - PreconCG::PreconCG(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, - const DiracMatrix &matEig, SolverParam &param) : + PCG::PCG(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, + const DiracMatrix &matEig, SolverParam &param) : Solver(mat, matSloppy, matPrecon, matEig, param), K(nullptr), Kparam(param) { fillInnerSolverParam(Kparam, param); @@ -30,8 +30,8 @@ namespace quda K = createPreconditioner(matPrecon, matPrecon, matPrecon, matEig, param, Kparam); } - PreconCG::PreconCG(const DiracMatrix &mat, Solver &K_, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, - const DiracMatrix &matEig, SolverParam &param) : + PCG::PCG(const DiracMatrix &mat, Solver &K_, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, + const DiracMatrix &matEig, SolverParam &param) : Solver(mat, matSloppy, matPrecon, matEig, param), K(nullptr), Kparam(param) { fillInnerSolverParam(Kparam, param); @@ -39,7 +39,7 @@ namespace quda K = wrapExternalPreconditioner(K_); } - PreconCG::~PreconCG() + PCG::~PCG() { getProfile().TPSTART(QUDA_PROFILE_FREE); @@ -49,84 +49,80 @@ namespace quda getProfile().TPSTOP(QUDA_PROFILE_FREE); } - void PreconCG::create(ColorSpinorField &x, const ColorSpinorField &b) + void PCG::create(cvector_ref &x, cvector_ref &b) { Solver::create(x, b); - if (!init) { + if (!init || r.size() != b.size()) { getProfile().TPSTART(QUDA_PROFILE_INIT); - ColorSpinorParam csParam(b); + ColorSpinorParam csParam(b[0]); - r = ColorSpinorField(b); - if (K) minvr = ColorSpinorField(b); + resize(r, b.size(),
csParam); csParam.create = QUDA_ZERO_FIELD_CREATE; - y = ColorSpinorField(csParam); + resize(y, b.size(), csParam); // create sloppy fields csParam.setPrecision(param.precision_sloppy); csParam.create = QUDA_NULL_FIELD_CREATE; - Ap = ColorSpinorField(csParam); + resize(Ap, b.size(), csParam); - x_sloppy = (!mixed() || !param.use_sloppy_partial_accumulator) ? - x.create_alias() : ColorSpinorField(csParam); + if (!mixed() || !param.use_sloppy_partial_accumulator) { + create_alias(x_sloppy, x); + } else { + resize(x_sloppy, b.size(), csParam); + } - csParam.create = QUDA_COPY_FIELD_CREATE; - csParam.field = &r; - r_sloppy = !mixed() ? r.create_alias() : ColorSpinorField(csParam); + if (!mixed()) { + create_alias(r_sloppy, r); + } else { + resize(r_sloppy, b.size(), csParam); + } if (K) { - csParam.field = &minvr; - minvr_sloppy = !mixed() ? minvr.create_alias() : ColorSpinorField(csParam); + resize(minvr_sloppy, b.size(), csParam); // create preconditioner intermediates - csParam.create = QUDA_NULL_FIELD_CREATE; csParam.setPrecision(Kparam.precision); - r_pre = ColorSpinorField(csParam); + resize(r_pre, b.size(), csParam); // Create minvr_pre - minvr_pre = ColorSpinorField(csParam); + resize(minvr_pre, b.size(), csParam); } Np = (param.solution_accumulator_pipeline == 0 ? 1 : param.solution_accumulator_pipeline); if (Np < 0 || Np > 16) errorQuda("Invalid value %d for solution_accumulator_pipeline", Np); - csParam.create = QUDA_NULL_FIELD_CREATE; - csParam.setPrecision(param.precision_sloppy); - x_update_batch = XUpdateBatch(Np, K ? 
minvr_sloppy : r_sloppy, csParam); - getProfile().TPSTOP(QUDA_PROFILE_INIT); init = true; } } - void PreconCG::solve_and_collect(ColorSpinorField &x, const ColorSpinorField &b, cvector_ref &v_r, - int collect_miniter, double collect_tol) + void PCG::solve_and_collect(cvector_ref &x, cvector_ref &b, + cvector_ref &v_r, int collect_miniter, double collect_tol) { - if (K) K->train_param(*this, b); + if (K) K->train_param(*this, b[0]); - create(x, b); + if (v_r.size() && x.size() > 1) errorQuda("Collect not supported for multi-RHS PCG"); getProfile().TPSTART(QUDA_PROFILE_INIT); // whether to select alternative reliable updates bool alternative_reliable = param.use_alternative_reliable; - double b2 = blas::norm2(b); + auto b2 = blas::norm2(b); // Check to see that we're not trying to invert on a zero-field source - if (b2 == 0 && param.compute_null_vector == QUDA_COMPUTE_NULL_VECTOR_NO) { + if (is_zero_src(x, b, b2)) { getProfile().TPSTOP(QUDA_PROFILE_INIT); - warningQuda("Warning: inverting on zero-field source"); - x = b; - param.true_res = 0.0; - param.true_res_hq = 0.0; return; } + create(x, b); + if (param.deflate) { // Construct the eigensolver and deflation space if requested. - constructDeflationSpace(b, matEig); + constructDeflationSpace(b[0], matEig); if (deflate_compute) { // compute the deflation space. (*eig_solve)(evecs, evals); @@ -138,22 +134,23 @@ namespace quda } } - double Anorm = 0; + double Anorm = 0.0; // for alternative reliable updates if (alternative_reliable) { // estimate norm for reliable updates - mat(r, b); - Anorm = sqrt(blas::norm2(r) / b2); + mat(r[0], b[0]); + Anorm = sqrt(norm2(r[0]) / b2[0]); } // compute initial residual - double r2 = 0.0; + vector r2(b.size(), 0.0); if (param.use_init_guess == QUDA_USE_INIT_GUESS_YES) { // Compute r = b - A * x mat(r, x); r2 = blas::xmyNorm(b, r); - if (b2 == 0) b2 = r2; + for (auto i = 0u; i < b.size(); i++) + if (b2[i] == 0) b2[i] = r2[i]; // y contains the original guess. 
blas::copy(y, x); } else { @@ -170,34 +167,44 @@ namespace quda } blas::zero(x); - if (&x != &x_sloppy) blas::zero(x_sloppy); + if (param.use_sloppy_partial_accumulator) blas::zero(x_sloppy); + if (r_sloppy[0].Precision() != r[0].Precision()) blas::copy(r_sloppy, r); + + auto csParam(r_sloppy[0]); + std::vector x_update_batch(b.size()); + for (auto i = 0u; i < b.size(); i++) + x_update_batch[i] = XUpdateBatch(Np, K ? minvr_sloppy[i] : r_sloppy[i], csParam); const bool use_heavy_quark_res = (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) ? true : false; if (K) { - r_pre = r_sloppy; + blas::copy(r_pre, r_sloppy); pushVerbosity(param.verbosity_precondition); (*K)(minvr_pre, r_pre); popVerbosity(); - minvr_sloppy = minvr_pre; + blas::copy(minvr_sloppy, minvr_pre); } getProfile().TPSTOP(QUDA_PROFILE_INIT); getProfile().TPSTART(QUDA_PROFILE_PREAMBLE); - double stop = stopping(param.tol, b2, param.residual_type); // stopping condition of solver - double heavy_quark_res = 0.0; // heavy quark residual - if (use_heavy_quark_res) heavy_quark_res = sqrt(HeavyQuarkResidualNorm(x, r).z); + auto stop = stopping(param.tol, b2, param.residual_type); // stopping condition of solver + auto stop_hq = std::vector(b.size(), param.tol_hq); - double beta = 0.0; - double pAp; - double rMinvr = 0; - double rMinvr_old = 0.0; - double r_new_Minvr_old = 0.0; - double r2_old = 0; - r2 = norm2(r); + std::vector heavy_quark_res(b.size(), 0.0); // heavy quark residual + if (use_heavy_quark_res) { + auto hq = HeavyQuarkResidualNorm(x, r); + for (auto i = 0u; i < b.size(); i++) heavy_quark_res[i] = sqrt(hq[i].z); + } + + std::vector beta(b.size(), 0.0); + std::vector pAp(b.size(), 0.0); + std::vector rMinvr(b.size(), 0.0); + std::vector rMinvr_old(b.size(), 0.0); + std::vector r_new_Minvr_old(b.size(), 0.0); + std::vector r2_old(b.size(), 0.0); - if (K) rMinvr = reDotProduct(r_sloppy, minvr_sloppy); + if (K) { rMinvr = reDotProduct(r_sloppy, minvr_sloppy); } 
getProfile().TPSTOP(QUDA_PROFILE_PREAMBLE); getProfile().TPSTART(QUDA_PROFILE_COMPUTE); @@ -221,44 +228,62 @@ namespace quda ru_params.hqmaxresIncrease = param.max_hq_res_increase; ru_params.hqmaxresRestartTotal = param.max_hq_res_restart_total; - ReliableUpdates ru(ru_params, r2); + ReliableUpdates ru(ru_params, r2[0]); - bool converged = convergence(r2, heavy_quark_res, stop, param.tol_hq); + bool converged = convergence(r2, heavy_quark_res, stop, stop_hq); - while (!converged && k < param.maxiter) { + auto get_p = [](std::vector &x_update_batch, bool next = false) { + vector_ref p; + p.reserve(x_update_batch.size()); + for (auto &x : x_update_batch) p.push_back(next ? x.get_next_field() : x.get_current_field()); + return p; + }; - matSloppy(Ap, x_update_batch.get_current_field()); + auto get_alpha = [](std::vector &x_update_batch) { + vector alpha; + alpha.reserve(x_update_batch.size()); + for (auto &x : x_update_batch) alpha.push_back(x.get_current_alpha()); + return alpha; + }; + + while (!converged && k < param.maxiter) { + auto p = get_p(x_update_batch); + auto p_next = get_p(x_update_batch, true); + matSloppy(Ap, p); - double sigma; // alternative reliable updates, if (alternative_reliable) { - double3 pAppp = blas::cDotProductNormA(x_update_batch.get_current_field(), Ap); - pAp = pAppp.x; - ru.update_ppnorm(pAppp.z); + auto pAppp = blas::cDotProductNormA(p, Ap); + for (auto i = 0u; i < b.size(); i++) pAp[i] = pAppp[i].x; + ru.update_ppnorm(pAppp[0].z); } else { - pAp = reDotProduct(x_update_batch.get_current_field(), Ap); + pAp = reDotProduct(p, Ap); } - x_update_batch.get_current_alpha() = (K) ? rMinvr / pAp : r2 / pAp; - double2 cg_norm = axpyCGNorm(-x_update_batch.get_current_alpha(), Ap, r_sloppy); + for (auto i = 0u; i < b.size(); i++) + x_update_batch[i].get_current_alpha() = K ? 
rMinvr[i] / pAp[i] : r2[i] / pAp[i]; + + auto cg_norm = axpyCGNorm(-get_alpha(x_update_batch), Ap, r_sloppy); // r --> r - alpha*A*p r2_old = r2; - r2 = cg_norm.x; - sigma = cg_norm.y >= 0.0 ? cg_norm.y : r2; // use r2 if (r_k+1, r_k-1 - r_k) breaks + vector sigma(b.size()); + for (auto i = 0u; i < b.size(); i++) { + r2[i] = cg_norm[i].x; + sigma[i] = cg_norm[i].y >= 0.0 ? cg_norm[i].y : r2[i]; // use r2 if (r_k+1, r_k-1 - r_k) breaks + } if (K) rMinvr_old = rMinvr; - ru.update_rNorm(sqrt(r2)); - - ru.evaluate(r2_old); + ru.update_rNorm(sqrt(r2[0])); + ru.evaluate(r2_old[0]); // force a reliable update if we are within target tolerance (only if doing reliable updates) - if (convergence(r2, heavy_quark_res, stop, param.tol_hq) && param.delta >= param.tol) ru.set_updateX(); + if (convergence(r2, heavy_quark_res, stop, stop_hq) && param.delta >= param.tol) ru.set_updateX(); - if (collect > 0 && k > collect_miniter && r2 < collect_tol * collect_tol * b2) { - v_r[v_r.size() - collect] = r_sloppy; - logQuda(QUDA_VERBOSE, "Collecting r %2d: r2 / b2 = %12.8e, k = %5d\n", collect, sqrt(r2 / b2), k); + if (collect > 0 && k > collect_miniter && r2[0] < collect_tol * collect_tol * b2[0]) { + blas::copy(v_r[v_r.size() - collect], r_sloppy); + logQuda(QUDA_VERBOSE, "Collecting r %2d: r2 / b2 = %12.8e, k = %5d\n", collect, sqrt(r2[0] / b2[0]), k); collect--; } @@ -277,35 +302,36 @@ namespace quda minvr_sloppy = minvr_pre; rMinvr = reDotProduct(r_sloppy, minvr_sloppy); - beta = (rMinvr - r_new_Minvr_old) / rMinvr_old; + for (auto i = 0u; i < b.size(); i++) beta[i] = (rMinvr[i] - r_new_Minvr_old[i]) / rMinvr_old[i]; } else { - beta = sigma / r2_old; // use the alternative beta computation + for (auto i = 0u; i < b.size(); i++) beta[i] = sigma[i] / r2_old[i]; // use the alternative beta computation } if (Np == 1) { - axpyZpbx(x_update_batch.get_current_alpha(), x_update_batch.get_current_field(), x_sloppy, - K ? 
minvr_sloppy : r_sloppy, beta); + axpyZpbx(get_alpha(x_update_batch), p, x_sloppy, K ? minvr_sloppy : r_sloppy, beta); } else { - if (x_update_batch.is_container_full()) { x_update_batch.accumulate_x(x_sloppy); } - blas::xpayz(K ? minvr_sloppy : r_sloppy, beta, x_update_batch.get_current_field(), - x_update_batch.get_next_field()); + for (auto i = 0u; i < b.size(); i++) { + if (x_update_batch[i].is_container_full()) x_update_batch[i].accumulate_x(x_sloppy[i]); + } + blas::xpayz(K ? minvr_sloppy : r_sloppy, beta, p, p_next); } - ru.accumulate_norm(x_update_batch.get_current_alpha()); + ru.accumulate_norm(get_alpha(x_update_batch)[0]); } else { // reliable update // Now that we are performing reliable update, need to update x with the p's that have // not been used yet - x_update_batch.accumulate_x(x_sloppy); - x_update_batch.reset_next(); - + for (auto i = 0u; i < b.size(); i++) { + x_update_batch[i].accumulate_x(x_sloppy[i]); + x_update_batch[i].reset_next(); + } xpy(x_sloppy, y); // y += x // Now compute r mat(r, y); r2 = xmyNorm(b, r); - if (param.deflate && sqrt(r2) < ru.maxr_deflate * param.tol_restart) { + if (param.deflate && sqrt(r2[0]) < ru.maxr_deflate * param.tol_restart) { // Deflate and accumulate to solution vector eig_solve->deflate(y, r, evecs, evals, true); @@ -313,7 +339,7 @@ namespace quda mat(r, y); r2 = blas::xmyNorm(b, r); - ru.update_maxr_deflate(r2); + ru.update_maxr_deflate(r2[0]); } copy(r_sloppy, r); @@ -321,11 +347,13 @@ namespace quda bool L2breakdown = false; double L2breakdown_eps = 0; - if (ru.reliable_break(r2, stop, L2breakdown, L2breakdown_eps)) { break; } + if (ru.reliable_break(r2[0], stop[0], L2breakdown, L2breakdown_eps)) { break; } - ru.update_norm(r2, y); + ru.update_norm(r2[0], y[0]); + ru.reset(r2[0]); - ru.reset(r2); + auto p = get_p(x_update_batch); + auto p_next = get_p(x_update_batch, true); if (K) { // can fuse these two kernels @@ -340,31 +368,35 @@ namespace quda minvr_sloppy = minvr_pre; rMinvr = 
reDotProduct(r_sloppy, minvr_sloppy); - beta = (rMinvr - r_new_Minvr_old) / rMinvr_old; + for (auto i = 0u; i < b.size(); i++) beta[i] = (rMinvr[i] - r_new_Minvr_old[i]) / rMinvr_old[i]; } else { // standard CG - no preconditioning // explicitly restore the orthogonality of the gradient vector - double rp = reDotProduct(r_sloppy, x_update_batch.get_current_field()) / (r2); - axpy(-rp, r_sloppy, x_update_batch.get_current_field()); + auto rp = cDotProduct(r_sloppy, p); + for (auto i = 0u; i < b.size(); i++) rp[i] /= r2[i]; + caxpy(-rp, r_sloppy, p); - beta = r2 / r2_old; + for (auto i = 0u; i < b.size(); i++) beta[i] = r2[i] / r2_old[i]; } - xpayz(K ? minvr_sloppy : r_sloppy, beta, x_update_batch.get_current_field(), x_update_batch.get_next_field()); + xpayz(K ? minvr_sloppy : r_sloppy, beta, p, p_next); } - ++k; + k++; PrintStats("PCG", k, r2, b2, heavy_quark_res); - converged = convergence(r2, heavy_quark_res, stop, param.tol_hq); + converged = convergence(r2, heavy_quark_res, stop, stop_hq); + // if we have converged and need to update any trailing solutions - if ((converged || k == param.maxiter) && ru.steps_since_reliable > 0 && !x_update_batch.is_container_full()) { - x_update_batch.accumulate_x(x_sloppy); - } + for (auto i = 0u; i < b.size(); i++) { + if ((converged || k == param.maxiter) && ru.steps_since_reliable > 0 && !x_update_batch[i].is_container_full()) { + x_update_batch[i].accumulate_x(x_sloppy[i]); + } - if (ru.steps_since_reliable == 0) { - x_update_batch.reset(); - } else { - ++x_update_batch; + if (ru.steps_since_reliable == 0) { + x_update_batch[i].reset(); + } else { + ++x_update_batch[i]; + } } } @@ -383,8 +415,8 @@ namespace quda // compute the true residual mat(r, x); - double true_res = xmyNorm(b, r); - param.true_res = sqrt(true_res / b2); + auto true_res = xmyNorm(b, r); + for (auto i = 0u; i < b.size(); i++) param.true_res[i] = sqrt(true_res[i] / b2[i]); getProfile().TPSTOP(QUDA_PROFILE_EPILOGUE); } diff --git a/lib/madwf_ml.cpp 
b/lib/madwf_ml.cpp index adb7ee323f..1df1ebd706 100644 --- a/lib/madwf_ml.cpp +++ b/lib/madwf_ml.cpp @@ -92,9 +92,9 @@ namespace quda if (getVerbosity() >= QUDA_VERBOSE) { printfQuda("Generating Null Space Vectors ... \n"); } spinorNoise(null_b, rng, QUDA_NOISE_GAUSS); - std::vector B(16); csParam.setPrecision(prec_precondition); - for (auto &pB : B) { pB = ColorSpinorField(csParam); } + std::vector B; + resize(B, 16, csParam); getProfile().TPSTOP(QUDA_PROFILE_INIT); null.solve_and_collect(null_x, null_b, B, param.madwf_null_miniter, param.madwf_null_tol); diff --git a/lib/solver.cpp b/lib/solver.cpp index da1128af09..8fc833beba 100644 --- a/lib/solver.cpp +++ b/lib/solver.cpp @@ -106,9 +106,9 @@ namespace quda { static_cast(param.preconditioner)->mg; // FIXME dirty hack to ensure that preconditioner precision set in interface isn't used in the outer GCR-MG solver if (!param.mg_instance) param.precision_precondition = param.precision_sloppy; - solver = new PreconCG(mat, *(mg), matSloppy, matPrecon, matEig, param); + solver = new PCG(mat, *(mg), matSloppy, matPrecon, matEig, param); } else { - solver = new PreconCG(mat, matSloppy, matPrecon, matEig, param); + solver = new PCG(mat, matSloppy, matPrecon, matEig, param); } break; case QUDA_BICGSTABL_INVERTER: @@ -399,6 +399,10 @@ namespace quda { bool Solver::convergence(cvector &r2, cvector &hq2, cvector &r2_tol, cvector &hq_tol) { + if (r2.size() != hq2.size() || r2.size() != r2_tol.size() || r2.size() != hq_tol.size()) + errorQuda("Mismatched vector lengths r2 = %lu hq2 = %lu r2_tol = %lu hq_tol = %lu", r2.size(), hq2.size(), + r2_tol.size(), hq_tol.size()); + for (auto i = 0u; i < r2.size(); i++) { // check the heavy quark residual norm if necessary if (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) { @@ -418,6 +422,9 @@ namespace quda { bool Solver::convergenceHQ(cvector &hq2, cvector &hq_tol) { + if (hq2.size() != hq_tol.size()) + errorQuda("Mismatched vector lengths hq2 = %lu hq_tol = %lu", hq2.size(), 
hq_tol.size()); + for (auto i = 0u; i < hq2.size(); i++) { // check the heavy quark residual norm if necessary if (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) { @@ -431,6 +438,9 @@ namespace quda { bool Solver::convergenceL2(cvector &r2, cvector &r2_tol) { + if (r2.size() != r2_tol.size()) + errorQuda("Mismatched vector lengths r2 = %lu r2_tol = %lu", r2.size(), r2_tol.size()); + for (auto i = 0u; i < r2.size(); i++) { // check the L2 relative residual norm if necessary if ((param.residual_type & QUDA_L2_RELATIVE_RESIDUAL) || (param.residual_type & QUDA_L2_ABSOLUTE_RESIDUAL)) { @@ -448,7 +458,12 @@ namespace quda { return rhs_str; } - void Solver::PrintStats(const char* name, int k, cvector &r2, cvector &b2, cvector &hq2) { + void Solver::PrintStats(const char *name, int k, cvector &r2, cvector &b2, cvector &hq2_) + { + auto hq2 = hq2_.size() == 0 ? vector(r2.size(), 0.0) : hq2_; + if (r2.size() != b2.size() || r2.size() != hq2.size()) + errorQuda("Mismatched vector lengths r2 = %lu b2 = %lu hq2 = %lu", r2.size(), b2.size(), hq2.size()); + for (auto i = 0u; i < r2.size(); i++) { auto rhs_str = set_rhs_str(i, r2.size()); if (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) { @@ -462,8 +477,14 @@ namespace quda { } } - void Solver::PrintSummary(const char *name, int k, cvector &r2, cvector &b2, - cvector &r2_tol, cvector &hq_tol) { + void Solver::PrintSummary(const char *name, int k, cvector &r2, cvector &b2, cvector &r2_tol, + cvector &hq_tol_) + { + auto hq_tol = hq_tol_.size() == 0 ? 
vector(r2.size(), 0.0) : hq_tol_; + if (r2.size() != b2.size() || r2.size() != r2_tol.size() || r2.size() != hq_tol.size()) + errorQuda("Mismatched vector lengths r2 = %lu b2 = %lu r2_tol = %lu hq_tol = %lu", r2.size(), b2.size(), + r2_tol.size(), hq_tol.size()); + for (auto i = 0u; i < r2.size(); i++) { auto rhs_str = set_rhs_str(i, r2.size()); if (param.compute_true_res) { From c7f05b5074d95f940f206d9d6e50223992973570 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Thu, 26 Sep 2024 15:06:45 -0700 Subject: [PATCH 079/103] Fix default asan options which broke when the separate test library was created --- tests/asan.h | 14 ++++++++++++++ tests/blas_test.cpp | 2 ++ tests/clover_force_test.cpp | 1 + tests/contract_ft_test.cpp | 1 + tests/dslash_test_utils.h | 1 + tests/eigensolve_test.cpp | 1 + tests/gauge_path_test.cpp | 1 + tests/hisq_paths_force_test.cpp | 1 + tests/hisq_stencil_ctest.cpp | 1 + tests/hisq_stencil_test.cpp | 1 + tests/hisq_unitarize_force_test.cpp | 1 + tests/invert_test.cpp | 1 + tests/staggered_dslash_test_utils.h | 1 + tests/staggered_eigensolve_test.cpp | 1 + tests/staggered_gsmear_test.cpp | 1 + tests/staggered_invert_test.cpp | 1 + tests/test.h | 1 + tests/unitarize_link_test.cpp | 1 + tests/utils/host_utils.cpp | 10 ---------- 19 files changed, 32 insertions(+), 10 deletions(-) create mode 100644 tests/asan.h diff --git a/tests/asan.h b/tests/asan.h new file mode 100644 index 0000000000..a56732248f --- /dev/null +++ b/tests/asan.h @@ -0,0 +1,14 @@ +#pragma once + +extern "C" { + + /** + @brief Set the default ASAN options. This ensures that QUDA just + works when SANITIZE is enabled without requiring ASAN_OPTIONS to + be set. We default disable leak checking, otherwise this will + cause ctest to fail with MPI library leaks. This declaration + cannot be in the test library, and must be in the test executable. 
+ */ + const char *__asan_default_options() { return "detect_leaks=0,protect_shadow_gap=0"; } + +} diff --git a/tests/blas_test.cpp b/tests/blas_test.cpp index 3a62f4ac9e..0bb6980c54 100644 --- a/tests/blas_test.cpp +++ b/tests/blas_test.cpp @@ -17,6 +17,8 @@ // google test #include +#include "test.h" + using namespace quda; /** diff --git a/tests/clover_force_test.cpp b/tests/clover_force_test.cpp index 7488b0856d..4a9edb57b2 100644 --- a/tests/clover_force_test.cpp +++ b/tests/clover_force_test.cpp @@ -4,6 +4,7 @@ #include "clover_force_reference.h" #include "misc.h" +#include "test.h" #include // convenient quark field container #include #include diff --git a/tests/contract_ft_test.cpp b/tests/contract_ft_test.cpp index 552a0527db..3ee1f6adb3 100644 --- a/tests/contract_ft_test.cpp +++ b/tests/contract_ft_test.cpp @@ -10,6 +10,7 @@ #include #include #include "misc.h" +#include "test.h" // google test #include diff --git a/tests/dslash_test_utils.h b/tests/dslash_test_utils.h index 837b0c91d3..01041fcf26 100644 --- a/tests/dslash_test_utils.h +++ b/tests/dslash_test_utils.h @@ -26,6 +26,7 @@ #include #include +#include "test.h" using namespace quda; diff --git a/tests/eigensolve_test.cpp b/tests/eigensolve_test.cpp index 5b063755cc..e325511e05 100644 --- a/tests/eigensolve_test.cpp +++ b/tests/eigensolve_test.cpp @@ -14,6 +14,7 @@ #include #include #include +#include "test.h" // Place params above "eigensolve_test_gtest.hpp" so they // are visible therein. 
diff --git a/tests/gauge_path_test.cpp b/tests/gauge_path_test.cpp index fee31cb53c..1efa49ef97 100644 --- a/tests/gauge_path_test.cpp +++ b/tests/gauge_path_test.cpp @@ -12,6 +12,7 @@ #include #include #include +#include "test.h" static QudaGaugeFieldOrder gauge_order = QUDA_QDP_GAUGE_ORDER; diff --git a/tests/hisq_paths_force_test.cpp b/tests/hisq_paths_force_test.cpp index 7560dbc105..d696ca1558 100644 --- a/tests/hisq_paths_force_test.cpp +++ b/tests/hisq_paths_force_test.cpp @@ -7,6 +7,7 @@ #include #include "gauge_field.h" #include "misc.h" +#include "test.h" #include "hisq_force_reference.h" #include "ks_improved_force.h" #include "momentum.h" diff --git a/tests/hisq_stencil_ctest.cpp b/tests/hisq_stencil_ctest.cpp index 6df6d1d977..811032ded2 100644 --- a/tests/hisq_stencil_ctest.cpp +++ b/tests/hisq_stencil_ctest.cpp @@ -1,3 +1,4 @@ +#include "test.h" #include "hisq_stencil_test_utils.h" using namespace quda; diff --git a/tests/hisq_stencil_test.cpp b/tests/hisq_stencil_test.cpp index 3b20287d3b..cc6581733d 100644 --- a/tests/hisq_stencil_test.cpp +++ b/tests/hisq_stencil_test.cpp @@ -1,3 +1,4 @@ +#include "test.h" #include "hisq_stencil_test_utils.h" using namespace quda; diff --git a/tests/hisq_unitarize_force_test.cpp b/tests/hisq_unitarize_force_test.cpp index 01e3c78c18..04e1576f20 100644 --- a/tests/hisq_unitarize_force_test.cpp +++ b/tests/hisq_unitarize_force_test.cpp @@ -7,6 +7,7 @@ #include #include "gauge_field.h" #include "misc.h" +#include "test.h" #include "hisq_force_reference.h" #include "ks_improved_force.h" #include diff --git a/tests/invert_test.cpp b/tests/invert_test.cpp index 92b68e71e9..9de3ddb8b9 100644 --- a/tests/invert_test.cpp +++ b/tests/invert_test.cpp @@ -13,6 +13,7 @@ #include #include #include +#include "test.h" QudaGaugeParam gauge_param; QudaInvertParam inv_param; diff --git a/tests/staggered_dslash_test_utils.h b/tests/staggered_dslash_test_utils.h index 4d336d8d4f..a9b499d71a 100644 --- 
a/tests/staggered_dslash_test_utils.h +++ b/tests/staggered_dslash_test_utils.h @@ -20,6 +20,7 @@ #include #include #include +#include "test.h" using namespace quda; diff --git a/tests/staggered_eigensolve_test.cpp b/tests/staggered_eigensolve_test.cpp index 335367b478..93ae358896 100644 --- a/tests/staggered_eigensolve_test.cpp +++ b/tests/staggered_eigensolve_test.cpp @@ -14,6 +14,7 @@ #include #include #include +#include "test.h" QudaGaugeParam gauge_param; QudaInvertParam eig_inv_param; diff --git a/tests/staggered_gsmear_test.cpp b/tests/staggered_gsmear_test.cpp index e5ce29c346..84f793dcb9 100644 --- a/tests/staggered_gsmear_test.cpp +++ b/tests/staggered_gsmear_test.cpp @@ -1,3 +1,4 @@ +#include "test.h" #include "staggered_gsmear_test_utils.h" using namespace quda; diff --git a/tests/staggered_invert_test.cpp b/tests/staggered_invert_test.cpp index d512c5463e..812ff3c56f 100644 --- a/tests/staggered_invert_test.cpp +++ b/tests/staggered_invert_test.cpp @@ -17,6 +17,7 @@ #include #include #include +#include "test.h" QudaGaugeParam gauge_param; QudaInvertParam inv_param; diff --git a/tests/test.h b/tests/test.h index 7b726433a3..27552f0c10 100644 --- a/tests/test.h +++ b/tests/test.h @@ -2,6 +2,7 @@ #include #include #include +#include "asan.h" struct quda_test { diff --git a/tests/unitarize_link_test.cpp b/tests/unitarize_link_test.cpp index 4cd8553fdd..f6cf4c642d 100644 --- a/tests/unitarize_link_test.cpp +++ b/tests/unitarize_link_test.cpp @@ -9,6 +9,7 @@ #include "host_utils.h" #include #include "misc.h" +#include "test.h" #include "util_quda.h" #include "llfat_quda.h" #include diff --git a/tests/utils/host_utils.cpp b/tests/utils/host_utils.cpp index c7b0808180..bbeae80fe9 100644 --- a/tests/utils/host_utils.cpp +++ b/tests/utils/host_utils.cpp @@ -786,16 +786,6 @@ int fullLatticeIndex(int i, int oddBit) return X; } -extern "C" { -/** - @brief Set the default ASAN options. 
This ensures that QUDA just - works when SANITIZE is enabled without requiring ASAN_OPTIONS to be - set. We default disable leak checking, otherwise this will cause - ctest to fail with MPI library leaks. - */ -const char *__asan_default_options() { return "detect_leaks=0,protect_shadow_gap=0"; } -} - /** * For MPI, the default node mapping is lexicographical with t varying fastest. */ From 2f3faab30fb14a75889e236e937923dde5a53e1c Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Thu, 26 Sep 2024 16:03:22 -0700 Subject: [PATCH 080/103] Conditionally print energy information --- tests/invert_test.cpp | 22 ++++++++++++++-------- tests/staggered_invert_test.cpp | 22 ++++++++++++++-------- 2 files changed, 28 insertions(+), 16 deletions(-) diff --git a/tests/invert_test.cpp b/tests/invert_test.cpp index 9de3ddb8b9..60df482455 100644 --- a/tests/invert_test.cpp +++ b/tests/invert_test.cpp @@ -239,9 +239,11 @@ std::vector> solve(test_t param) printfQuda("MG Setup Done: %g secs, %g Gflops\n", mg_param.invert_param->secs, mg_param.invert_param->gflops / mg_param.invert_param->secs); - printfQuda("Energy = %g J, Mean power = %g W, mean temp = %g C, mean clock = %f\n", - mg_param.invert_param->energy, mg_param.invert_param->power, - mg_param.invert_param->temp, mg_param.invert_param->clock); + if (mg_param.invert_param->energy > 0) { + printfQuda("Energy = %g J, Mean power = %g W, mean temp = %g C, mean clock = %f\n", + mg_param.invert_param->energy, mg_param.invert_param->power, + mg_param.invert_param->temp, mg_param.invert_param->clock); + } } // Vector construct START @@ -335,8 +337,10 @@ std::vector> solve(test_t param) iter[i] = inv_param.iter; printfQuda("Done: %i iter / %g secs = %g Gflops\n", inv_param.iter, inv_param.secs, inv_param.gflops / inv_param.secs); - printfQuda("Energy = %g J, Mean power = %g W, mean temp = %g C, mean clock = %f\n", - inv_param.energy, inv_param.power, inv_param.temp, inv_param.clock); + if (inv_param.energy > 0) { + printfQuda("Energy 
= %g J, Mean power = %g W, mean temp = %g C, mean clock = %f\n", + inv_param.energy, inv_param.power, inv_param.temp, inv_param.clock); + } } } else { @@ -373,9 +377,11 @@ std::vector> solve(test_t param) num_sub_partition, inv_param.iter, inv_param.secs, inv_param.gflops / inv_param.secs, inv_param.secs / Nsrc_tile); - printfQuda("Energy = %g J (%g J per source), Mean power = %g W, mean temp = %g C, mean clock = %f\n", - inv_param.energy, inv_param.energy / Nsrc_tile, - inv_param.power, inv_param.temp, inv_param.clock); + if (inv_param.energy > 0) { + printfQuda("Energy = %g J (%g J per source), Mean power = %g W, mean temp = %g C, mean clock = %f\n", + inv_param.energy, inv_param.energy / Nsrc_tile, + inv_param.power, inv_param.temp, inv_param.clock); + } } } diff --git a/tests/staggered_invert_test.cpp b/tests/staggered_invert_test.cpp index 812ff3c56f..524270c270 100644 --- a/tests/staggered_invert_test.cpp +++ b/tests/staggered_invert_test.cpp @@ -285,9 +285,11 @@ std::vector> solve(test_t param) printfQuda("MG Setup Done: %g secs, %g Gflops\n", mg_param.invert_param->secs, mg_param.invert_param->gflops / mg_param.invert_param->secs); - printfQuda("Energy = %g J, Mean power = %g W, mean temp = %g C, mean clock = %f\n", - mg_param.invert_param->energy, mg_param.invert_param->power, - mg_param.invert_param->temp, mg_param.invert_param->clock); + if (mg_param.invert_param->energy > 0) { + printfQuda("Energy = %g J, Mean power = %g W, mean temp = %g C, mean clock = %f\n", + mg_param.invert_param->energy, mg_param.invert_param->power, + mg_param.invert_param->temp, mg_param.invert_param->clock); + } } // Staggered vector construct START @@ -392,8 +394,10 @@ std::vector> solve(test_t param) iter[n] = inv_param.iter; printfQuda("Done: %i iter / %g secs = %g Gflops\n", inv_param.iter, inv_param.secs, inv_param.gflops / inv_param.secs); - printfQuda("Energy = %g J, Mean power = %g W, mean temp = %g C, mean clock = %f\n\n", - inv_param.energy, inv_param.power, 
inv_param.temp, inv_param.clock); + if (inv_param.energy > 0) { + printfQuda("Energy = %g J, Mean power = %g W, mean temp = %g C, mean clock = %f\n\n", + inv_param.energy, inv_param.power, inv_param.temp, inv_param.clock); + } } } else { @@ -427,9 +431,11 @@ std::vector> solve(test_t param) printfQuda("Done: %d sub-partitions - %i iter / %g secs = %g Gflops, %g secs per source\n", num_sub_partition, inv_param.iter, inv_param.secs, inv_param.gflops / inv_param.secs, inv_param.secs / Nsrc_tile); - printfQuda("Energy = %g J (%g J per source), Mean power = %g W, mean temp = %g C, mean clock = %f\n\n", - inv_param.energy, inv_param.energy / Nsrc_tile, - inv_param.power, inv_param.temp, inv_param.clock); + if (inv_param.energy > 0) { + printfQuda("Energy = %g J (%g J per source), Mean power = %g W, mean temp = %g C, mean clock = %f\n\n", + inv_param.energy, inv_param.energy / Nsrc_tile, + inv_param.power, inv_param.temp, inv_param.clock); + } } } From 5af9d016c56bce7a469ad3c1325d0d0aa55133cb Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Thu, 26 Sep 2024 17:09:47 -0700 Subject: [PATCH 081/103] Reduce memory for clover force (use smaller halo for extended field --- lib/interface_quda.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/lib/interface_quda.cpp b/lib/interface_quda.cpp index 01d0035440..5e3b799f24 100644 --- a/lib/interface_quda.cpp +++ b/lib/interface_quda.cpp @@ -4595,6 +4595,8 @@ void computeCloverForceQuda(void *h_mom, double dt, void **h_x, void **, double // Make sure extendedGaugeResident has the correct R if (extendedGaugeResident) delete extendedGaugeResident; + lat_dim_t R; + for (int d=0; d<4; d++) R[d] = (d==0 ? 
2 : 1) * (redundant_comms || commDimPartitioned(d)); extendedGaugeResident = createExtendedGauge(*gaugePrecise, R, getProfile()); GaugeField &gaugeEx = *extendedGaugeResident; @@ -4668,6 +4670,8 @@ void computeTMCloverForceQuda(void *h_mom, void **h_x, void **h_x0, double *coef // Make sure extendedGaugeResident has the correct R if (extendedGaugeResident) delete extendedGaugeResident; + lat_dim_t R; + for (int d=0; d<4; d++) R[d] = (d==0 ? 2 : 1) * (redundant_comms || commDimPartitioned(d)); extendedGaugeResident = createExtendedGauge(*gaugePrecise, R, profileTMCloverForce); GaugeField &gaugeEx = *extendedGaugeResident; From 764fb78c250094595abafb1c6f933a77240c6a04 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Thu, 26 Sep 2024 17:10:10 -0700 Subject: [PATCH 082/103] Fix compiler warning --- tests/contract_ft_test.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/contract_ft_test.cpp b/tests/contract_ft_test.cpp index 3ee1f6adb3..cc793ede5b 100644 --- a/tests/contract_ft_test.cpp +++ b/tests/contract_ft_test.cpp @@ -114,7 +114,6 @@ inline void fill_buffers(std::array, N> &buffs, const std::ar srand(l); for (int i = 0; i < dofs; i++) { -#pragma unroll for (int n = 0; n < N; n++) { buffs[n][ll * dofs + i] = 2. 
* (rand() / (Float)RAND_MAX) - 1.; } } } From 7523f3a5749840f0ff26ce1130713043cde2a656 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Thu, 26 Sep 2024 17:13:06 -0700 Subject: [PATCH 083/103] Fix bug in CA CG --- lib/inv_ca_cg.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/inv_ca_cg.cpp b/lib/inv_ca_cg.cpp index 9afd8c37dd..96060b9e9c 100644 --- a/lib/inv_ca_cg.cpp +++ b/lib/inv_ca_cg.cpp @@ -535,7 +535,7 @@ namespace quda resIncreaseTotal++; warningQuda( "CA-CG: new reliable residual norm %e is greater than previous reliable residual norm %e (total #inc %i)", - sqrt(r2[0]), sqrt(r2_old[9]), resIncreaseTotal); + sqrt(r2[0]), sqrt(r2_old[0]), resIncreaseTotal); if (resIncrease > maxResIncrease or resIncreaseTotal > maxResIncreaseTotal) { warningQuda("CA-CG: solver exiting due to too many true residual norm increases"); break; From e46b436d4341e1264d6ec39ccdd0da1bb02e51db Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Fri, 27 Sep 2024 16:54:43 -0700 Subject: [PATCH 084/103] Work arounds for NVSHMEM due to coarse grained synchronization used in uber kernel: do not use > 16 rhs with uber kernel; reduce x block size if using uber kernel. Also fixes a bug with shmem tuning: step_y/step_z should be reverted after uber tuning is complete --- include/dslash.h | 20 +++++++++----------- lib/dslash_policy.hpp | 8 ++++++-- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/include/dslash.h b/include/dslash.h index 2643fbbc4e..e51ea64b3e 100644 --- a/include/dslash.h +++ b/include/dslash.h @@ -156,8 +156,8 @@ namespace quda } } - virtual int blockStep() const override { return 16; } - virtual int blockMin() const override { return 16; } + virtual int blockStep() const override { return (arg.shmem & 64) ? 8 : 16; } + virtual int blockMin() const override { return (arg.shmem & 64) ? 
8 : 16; } unsigned int maxSharedBytesPerBlock() const override { return maxDynamicSharedBytesPerBlock(); } @@ -207,13 +207,12 @@ namespace quda virtual void initTuneParam(TuneParam ¶m) const override { - /* for nvshmem uber kernels the current synchronization requires use to keep the y and z dimension local to the + /* for nvshmem uber kernels the current synchronization requires us to keep the y and z dimension local to the * block. This can be removed when we introduce a finer grained synchronization which takes into account the y and * z components explicitly */ - if (arg.shmem & 64) { - step_y = vector_length_y; - step_z = vector_length_z; - } + step_y = arg.shmem & 64 ? vector_length_y : 1; + step_z = arg.shmem & 64 ? vector_length_z : 1; + TunableKernel3D::initTuneParam(param); if (arg.pack_threads && (arg.kernel_type == INTERIOR_KERNEL || arg.kernel_type == UBER_KERNEL)) param.aux.x = 1; // packing blocks per direction @@ -225,10 +224,9 @@ namespace quda /* for nvshmem uber kernels the current synchronization requires use to keep the y and z dimension local to the * block. This can be removed when we introduce a finer grained synchronization which takes into account the y and * z components explicitly. */ - if (arg.shmem & 64) { - step_y = vector_length_y; - step_z = vector_length_z; - } + step_y = arg.shmem & 64 ? vector_length_y : 1; + step_z = arg.shmem & 64 ? 
vector_length_z : 1; + TunableKernel3D::defaultTuneParam(param); if (arg.pack_threads && (arg.kernel_type == INTERIOR_KERNEL || arg.kernel_type == UBER_KERNEL)) param.aux.x = 1; // packing blocks per direction diff --git a/lib/dslash_policy.hpp b/lib/dslash_policy.hpp index dce21f1be3..f7d024085f 100644 --- a/lib/dslash_policy.hpp +++ b/lib/dslash_policy.hpp @@ -1859,8 +1859,12 @@ namespace quda } if (comm_nvshmem_enabled()) { - enable_policy(QudaDslashPolicy::QUDA_SHMEM_UBER_PACKINTRA_DSLASH); - enable_policy(QudaDslashPolicy::QUDA_SHMEM_UBER_PACKFULL_DSLASH); + if (in.size() <= 16) { + // FIXME until uber dslash gets fine-grained + // synchronization, we cannot use it with large RHS + enable_policy(QudaDslashPolicy::QUDA_SHMEM_UBER_PACKINTRA_DSLASH); + enable_policy(QudaDslashPolicy::QUDA_SHMEM_UBER_PACKFULL_DSLASH); + } enable_policy(QudaDslashPolicy::QUDA_SHMEM_PACKINTRA_DSLASH); enable_policy(QudaDslashPolicy::QUDA_SHMEM_PACKFULL_DSLASH); } From 093e0e6fac172c92cc62869739eeb67f5e152aba Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Tue, 1 Oct 2024 10:29:45 -0700 Subject: [PATCH 085/103] Split grid true residuals now correctly returned in QudaInvertParam struct --- include/invert_quda.h | 47 ++++++++++++----------------- include/split_grid.h | 2 +- lib/interface_quda.cpp | 5 ++++ lib/solver.cpp | 68 ++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 93 insertions(+), 29 deletions(-) diff --git a/include/invert_quda.h b/include/invert_quda.h index 891a5fa947..c235a98192 100644 --- a/include/invert_quda.h +++ b/include/invert_quda.h @@ -355,35 +355,12 @@ namespace quda { SolverParam(const SolverParam ¶m) = default; /** - Update the QudaInvertParam with the data from this - @param param the QudaInvertParam to be updated + @brief Update the QudaInvertParam with the data from this + instance (update the true residuals, and other observables). 
+ @param[in,out] param the QudaInvertParam to be updated + @param[in] offset index of the single offset entry to update; if negative, all offset entries are updated */ - void updateInvertParam(QudaInvertParam &param, int offset=-1) { - for (auto i = 0u; i < true_res.size(); i++) param.true_res[i] = true_res[i]; - for (auto i = 0u; i < true_res_hq.size(); i++) param.true_res_hq[i] = true_res_hq[i]; - param.iter += iter; - if (offset >= 0) { - param.true_res_offset[offset] = true_res_offset[offset]; - param.iter_res_offset[offset] = iter_res_offset[offset]; - param.true_res_hq_offset[offset] = true_res_hq_offset[offset]; - } else { - for (int i=0; i(param.eig_param) = eig_param; - } + void updateInvertParam(QudaInvertParam &param, int offset = -1); // for incremental eigCG: void updateRhsIndex(QudaInvertParam &param) { rhs_idx = param.rhs_idx; } @@ -1675,4 +1652,18 @@ namespace quda { */ bool is_ca_solver(QudaInverterType type); + /** + @brief Join the separate split-grid instances of + QudaInvertParam. This function places the computed residuals + for each solve from the split grids in the expected order. + This function expects we are using the default (global) + communicator. + + @param[in, out] out The global joined instance of QudaInvertParam + @param[in] in The local split-grid instance of QudaInvertParam + @param[in] comm_key The CommKey that defines the split grid used + @param[in] split_rank The rank of the process when in split grid + */ + void joinInvertParam(QudaInvertParam &out, const QudaInvertParam &in, const CommKey &comm_key, int split_rank); + } // namespace quda diff --git a/include/split_grid.h b/include/split_grid.h index 19a37c4311..2bd42bb476 100644 --- a/include/split_grid.h +++ b/include/split_grid.h @@ -37,7 +37,7 @@ namespace quda things and the extension to 4d is trivial. */ - auto processor_dim = comm_grid_dim / comm_key; // How many processors are there in a processor grid sub-parititon? + auto processor_dim = comm_grid_dim / comm_key; // How many processors are there in a processor grid sub-partition?
auto partition_dim = comm_grid_dim / processor_dim; // How many such sub-partitions are there? partition_dim == comm_key diff --git a/lib/interface_quda.cpp b/lib/interface_quda.cpp index 5e3b799f24..6939e8be34 100644 --- a/lib/interface_quda.cpp +++ b/lib/interface_quda.cpp @@ -3312,11 +3312,16 @@ void callMultiSrcQuda(void **_hp_x, void **_hp_b, QudaInvertParam *param, // col for (auto i = 0u; i < b_raw.size(); i++) b_raw[i] = _collect_b[i].data(); op(x_raw, b_raw, param_copy, args...); + auto split_rank = comm_rank(); + profileInvertMultiSrc.TPSTART(QUDA_PROFILE_EPILOGUE); push_communicator(default_comm_key); updateR(); comm_barrier(); + // back to the default communicator, now join the param entries + joinInvertParam(*param, param_copy, split_key, split_rank); + // Join spinors: _h_x are aliases to host pointers in 'external order: QDP++, QDP-JIT, etc' for (int n = 0; n < param->num_src_per_sub_partition; n++) { // join fields diff --git a/lib/solver.cpp b/lib/solver.cpp index 8fc833beba..877c7c1ac8 100644 --- a/lib/solver.cpp +++ b/lib/solver.cpp @@ -585,4 +585,72 @@ namespace quda { return zero_src; } + void SolverParam::updateInvertParam(QudaInvertParam &param, int offset) + { + for (auto i = 0u; i < true_res.size(); i++) param.true_res[i] = true_res[i]; + for (auto i = 0u; i < true_res_hq.size(); i++) param.true_res_hq[i] = true_res_hq[i]; + param.iter += iter; + if (offset >= 0) { + param.true_res_offset[offset] = true_res_offset[offset]; + param.iter_res_offset[offset] = iter_res_offset[offset]; + param.true_res_hq_offset[offset] = true_res_hq_offset[offset]; + } else { + for (int i = 0; i < num_offset; i++) { + param.true_res_offset[i] = true_res_offset[i]; + param.iter_res_offset[i] = iter_res_offset[i]; + param.true_res_hq_offset[i] = true_res_hq_offset[i]; + } + } + // for incremental eigCG: + param.rhs_idx = rhs_idx; + + param.ca_lambda_min = ca_lambda_min; + param.ca_lambda_max = ca_lambda_max; + + param.ca_lambda_min_precondition =
ca_lambda_min_precondition; + param.ca_lambda_max_precondition = ca_lambda_max_precondition; + + if (deflate) *static_cast(param.eig_param) = eig_param; + } + + void joinInvertParam(QudaInvertParam &out, const QudaInvertParam &in, const CommKey &split_key, int split_rank) + { + auto num_sub_partition = quda::product(split_key); + + int sub_partition_dims[] + = {comm_dim(0) / split_key[0], comm_dim(1) / split_key[1], comm_dim(2) / split_key[2], comm_dim(3) / split_key[3]}; + + int sub_partition_coords[] = {comm_coord(0) / sub_partition_dims[0], comm_coord(1) / sub_partition_dims[1], + comm_coord(2) / sub_partition_dims[2], comm_coord(3) / sub_partition_dims[3]}; + + auto j = sub_partition_coords[3]; + for (auto d = 2; d >= 0; d--) j = j * split_key[d] + sub_partition_coords[d]; + + std::vector true_res(in.num_src, 0.0); + std::vector true_res_hq(in.num_src, 0.0); + if (split_rank == 0) { // only rank 0 in each sub partition sets the residuals + for (auto i = 0; i < in.num_src_per_sub_partition; i++) { + true_res[i * num_sub_partition + j] = in.true_res[i]; + true_res_hq[i * num_sub_partition + j] = in.true_res_hq[i]; + } + } + + // communicate to all ranks + comm_allreduce_sum(true_res_hq); + comm_allreduce_sum(true_res); + memcpy(out.true_res, true_res.data(), true_res.size() * sizeof(double)); + memcpy(out.true_res_hq, true_res_hq.data(), true_res_hq.size() * sizeof(double)); + + out.iter = in.iter; + comm_allreduce_int(out.iter); + + out.ca_lambda_min = in.ca_lambda_min; + out.ca_lambda_max = in.ca_lambda_max; + out.ca_lambda_min_precondition = in.ca_lambda_min_precondition; + out.ca_lambda_max_precondition = in.ca_lambda_max_precondition; + + // now broadcast from global rank 0 to ensure uniformity + comm_broadcast(&out, sizeof(QudaInvertParam)); + } + } // namespace quda From a3e889d4128ea58b8afe0b89f2c69e887e9f1b09 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Tue, 1 Oct 2024 14:35:08 -0700 Subject: [PATCH 086/103] invert_test should never call 
invertMultiSrcQuda is multishift is intended --- tests/invert_test.cpp | 2 +- tests/staggered_invert_test.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/invert_test.cpp b/tests/invert_test.cpp index 60df482455..7005075d9b 100644 --- a/tests/invert_test.cpp +++ b/tests/invert_test.cpp @@ -316,7 +316,7 @@ std::vector> solve(test_t param) verifySpinorDistanceReweight(in[0], distance_pc_alpha0, distance_pc_t0); } - if (!use_multi_src) { + if (!use_multi_src || multishift > 1) { for (int i = 0; i < Nsrc; i++) { // If deflating, preserve the deflation space between solves diff --git a/tests/staggered_invert_test.cpp b/tests/staggered_invert_test.cpp index 524270c270..7f1c8dd053 100644 --- a/tests/staggered_invert_test.cpp +++ b/tests/staggered_invert_test.cpp @@ -373,7 +373,7 @@ std::vector> solve(test_t param) // QUDA invert test //---------------------------------------------------------------------------- - if (!use_multi_src) { + if (!use_multi_src || multishift > 1) { for (int n = 0; n < Nsrc; n++) { // If deflating, preserve the deflation space between solves From 9933e4d112d0079f78f4e868e2cbc41852894cf1 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Tue, 1 Oct 2024 15:04:41 -0700 Subject: [PATCH 087/103] Some MRHS HQ related solver fixes --- lib/inv_bicgstab_quda.cpp | 9 +++++---- lib/inv_bicgstabl_quda.cpp | 4 ++-- lib/inv_ca_cg.cpp | 2 +- lib/inv_ca_gcr.cpp | 4 ++-- lib/inv_gcr_quda.cpp | 7 ++++--- 5 files changed, 14 insertions(+), 12 deletions(-) diff --git a/lib/inv_bicgstab_quda.cpp b/lib/inv_bicgstab_quda.cpp index 73e3f9868e..be5c1583ff 100644 --- a/lib/inv_bicgstab_quda.cpp +++ b/lib/inv_bicgstab_quda.cpp @@ -151,6 +151,7 @@ namespace quda { getProfile().TPSTART(QUDA_PROFILE_PREAMBLE); auto stop = stopping(param.tol, b2, param.residual_type); // stopping condition of solver + auto stop_hq = std::vector(b.size(), param.tol_hq); const bool use_heavy_quark_res = (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) ? 
true : false; @@ -187,7 +188,7 @@ namespace quda { rho = r2; // cDotProductCuda(r0, r_sloppy); // BiCRstab blas::copy(p, r_sloppy); - bool converged = convergence(r2, heavy_quark_res, stop, param.tol_hq); + bool converged = convergence(r2, heavy_quark_res, stop, stop_hq); // track if we just performed an exact recalculation of y, r, r2 bool just_updated = false; @@ -299,7 +300,7 @@ namespace quda { k++; PrintStats("BiCGstab", k, r2, b2, heavy_quark_res); - converged = convergence(r2, heavy_quark_res, stop, param.tol_hq); + converged = convergence(r2, heavy_quark_res, stop, stop_hq); if (converged) { // make sure we've truly converged @@ -336,7 +337,7 @@ namespace quda { } // Update convergence check - converged = convergence(r2, heavy_quark_res, stop, param.tol_hq); + converged = convergence(r2, heavy_quark_res, stop, stop_hq); } // update p @@ -374,7 +375,7 @@ namespace quda { param.true_res[i] = sqrt(r2[i] / b2[i]); param.true_res_hq[i] = sqrt(hq[i].z); } - PrintSummary("BiCGstab", k, r2, b2, stop, param.tol_hq); + PrintSummary("BiCGstab", k, r2, b2, stop, stop_hq); } getProfile().TPSTOP(QUDA_PROFILE_EPILOGUE); diff --git a/lib/inv_bicgstabl_quda.cpp b/lib/inv_bicgstabl_quda.cpp index 42f5b50fbb..98176af28a 100644 --- a/lib/inv_bicgstabl_quda.cpp +++ b/lib/inv_bicgstabl_quda.cpp @@ -603,7 +603,7 @@ namespace quda { double maxrx = rNorm; // The same. Would be different if we did 'x' reliable updates. 
PrintStats(solver_name.c_str(), total_iter, r2, b2, heavy_quark_res); - while (!convergence(r2, 0.0, stop, 0.0) && total_iter < param.maxiter) { + while (!convergenceL2(r2, stop) && total_iter < param.maxiter) { // rho0 = -omega*rho0; for (auto i = 0u; i < b.size(); i++) rho0[i] *= -omega[i]; @@ -754,7 +754,7 @@ namespace quda { getProfile().TPSTOP(QUDA_PROFILE_EPILOGUE); - PrintSummary(solver_name.c_str(), total_iter, r2, b2, stop, param.tol_hq); + PrintSummary(solver_name.c_str(), total_iter, r2, b2, stop); } } // namespace quda diff --git a/lib/inv_ca_cg.cpp b/lib/inv_ca_cg.cpp index 96060b9e9c..13544a6a1e 100644 --- a/lib/inv_ca_cg.cpp +++ b/lib/inv_ca_cg.cpp @@ -394,7 +394,7 @@ namespace quda blas::copy(S[0], r); // no op if uni-precision PrintStats("CA-CG", total_iter, r2, b2, heavy_quark_res); - while (!convergence(r2, heavy_quark_res, stop, param.tol_hq) && total_iter < param.maxiter) { + while (!convergenceL2(r2, stop) && total_iter < param.maxiter) { // build up a space of size n_krylov, assumes S[0] is in place computeCAKrylovSpace(matSloppy, AS, S, n_krylov, basis, m_map, b_map); diff --git a/lib/inv_ca_gcr.cpp b/lib/inv_ca_gcr.cpp index 3a74acb00b..c755275bf6 100644 --- a/lib/inv_ca_gcr.cpp +++ b/lib/inv_ca_gcr.cpp @@ -282,7 +282,7 @@ namespace quda }; PrintStats("CA-GCR", total_iter, r2, b2, heavy_quark_res); - while (!convergence(r2, heavy_quark_res, stop, param.tol_hq) && total_iter < param.maxiter) { + while (!convergenceL2(r2, stop) && total_iter < param.maxiter) { // build up a space of size n_krylov computeCAKrylovSpace(matSloppy, q, p, n_krylov, basis, m_map, b_map); @@ -356,7 +356,7 @@ namespace quda } // No matter what, if we haven't converged, we do a restart. 
- if (!convergence(r2, heavy_quark_res, stop, param.tol_hq)) { + if (!convergenceL2(r2, stop)) { restart++; // restarting if residual is still too great PrintStats("CA-GCR (restart)", restart, r2, b2, heavy_quark_res); diff --git a/lib/inv_gcr_quda.cpp b/lib/inv_gcr_quda.cpp index d7df9eb4fc..d0d8031413 100644 --- a/lib/inv_gcr_quda.cpp +++ b/lib/inv_gcr_quda.cpp @@ -261,6 +261,7 @@ namespace quda { } auto stop = stopping(param.tol, b2, param.residual_type); // stopping condition of solver + auto stop_hq = vector(b.size(), param.tol_hq); const bool use_heavy_quark_res = (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) ? true : false; @@ -301,7 +302,7 @@ namespace quda { k_break = 0; PrintStats("GCR", total_iter+k, r2, b2, heavy_quark_res); - while ( !convergence(r2, heavy_quark_res, stop, param.tol_hq) && total_iter < param.maxiter) { + while ( !convergence(r2, heavy_quark_res, stop, stop_hq) && total_iter < param.maxiter) { if (K) { pushVerbosity(param.verbosity_precondition); @@ -388,7 +389,7 @@ namespace quda { k_break = k; k = 0; - if ( !convergence(r2, heavy_quark_res, stop, param.tol_hq) ) { + if ( !convergence(r2, heavy_quark_res, stop, stop_hq) ) { restart++; // restarting if residual is still too great PrintStats("GCR (restart)", restart, r2, b2, heavy_quark_res); @@ -426,7 +427,7 @@ namespace quda { getProfile().TPSTOP(QUDA_PROFILE_EPILOGUE); - PrintSummary("GCR", total_iter, r2, b2, stop, param.tol_hq); + PrintSummary("GCR", total_iter, r2, b2, stop, stop_hq); } } // namespace quda From da38153b5f0a8b78b144cf000d92af9147a82b53 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Tue, 1 Oct 2024 15:14:47 -0700 Subject: [PATCH 088/103] More MRHS solver fixes --- lib/inv_ca_cg.cpp | 2 +- lib/inv_ca_gcr.cpp | 2 +- lib/inv_cgne.cpp | 2 +- lib/inv_cgnr.cpp | 2 +- lib/inv_mr_quda.cpp | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/lib/inv_ca_cg.cpp b/lib/inv_ca_cg.cpp index 13544a6a1e..86b6429532 100644 --- a/lib/inv_ca_cg.cpp +++ 
b/lib/inv_ca_cg.cpp @@ -570,7 +570,7 @@ namespace quda param.iter += total_iter; } - PrintSummary("CA-CG", total_iter, r2, b2, stop, param.tol_hq); + PrintSummary("CA-CG", total_iter, r2, b2, stop); if (param.is_preconditioner) commGlobalReductionPop(); } diff --git a/lib/inv_ca_gcr.cpp b/lib/inv_ca_gcr.cpp index c755275bf6..aa41c9409b 100644 --- a/lib/inv_ca_gcr.cpp +++ b/lib/inv_ca_gcr.cpp @@ -390,7 +390,7 @@ namespace quda param.iter += total_iter; } - PrintSummary("CA-GCR", total_iter, r2, b2, stop, param.tol_hq); + PrintSummary("CA-GCR", total_iter, r2, b2, stop); if (param.is_preconditioner) commGlobalReductionPop(); } diff --git a/lib/inv_cgne.cpp b/lib/inv_cgne.cpp index 9e7f19c97f..12fe55ec95 100644 --- a/lib/inv_cgne.cpp +++ b/lib/inv_cgne.cpp @@ -96,7 +96,7 @@ namespace quda r2 = blas::norm2(xe); } for (auto i = 0u; i < b.size(); i++) param.true_res[i] = sqrt(r2[i] / b2[i]); - PrintSummary("CGNE", param.iter - iter0, r2, b2, stopping(param.tol, b2, param.residual_type), param.tol_hq); + PrintSummary("CGNE", param.iter - iter0, r2, b2, stopping(param.tol, b2, param.residual_type)); } } diff --git a/lib/inv_cgnr.cpp b/lib/inv_cgnr.cpp index 9dc7b67e0a..e143ad1187 100644 --- a/lib/inv_cgnr.cpp +++ b/lib/inv_cgnr.cpp @@ -82,7 +82,7 @@ namespace quda r2 = blas::norm2(br); } for (auto i = 0u; i < b.size(); i++) param.true_res[i] = sqrt(r2[i] / b2[i]); - PrintSummary("CGNR", param.iter - iter0, r2, b2, stopping(param.tol, b2, param.residual_type), param.tol_hq); + PrintSummary("CGNR", param.iter - iter0, r2, b2, stopping(param.tol, b2, param.residual_type)); } } } diff --git a/lib/inv_mr_quda.cpp b/lib/inv_mr_quda.cpp index 566ff5661e..d8565286fa 100644 --- a/lib/inv_mr_quda.cpp +++ b/lib/inv_mr_quda.cpp @@ -159,7 +159,7 @@ namespace quda step++; } - PrintSummary("MR", iter, r2, b2, stopping(param.tol, b2, param.residual_type), param.tol_hq); + PrintSummary("MR", iter, r2, b2, stopping(param.tol, b2, param.residual_type)); if (!param.is_preconditioner) { 
getProfile().TPSTOP(QUDA_PROFILE_COMPUTE); From e5d74e4b8d41c631ed4d537d3716a3785b33a6f3 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Tue, 1 Oct 2024 15:15:45 -0700 Subject: [PATCH 089/103] Add QudaInvertParam::energy/power/temp/clock to check_param --- lib/check_params.h | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/lib/check_params.h b/lib/check_params.h index 7d7b694eff..266f12d727 100644 --- a/lib/check_params.h +++ b/lib/check_params.h @@ -675,10 +675,18 @@ void printQudaInvertParam(QudaInvertParam *param) { P(iter, 0); P(gflops, 0.0); P(secs, 0.0); + P(energy, 0.0); + P(power, 0.0); + P(temp, 0.0); + P(clock, 0.0); #elif defined(PRINT_PARAM) P(iter, INVALID_INT); P(gflops, INVALID_DOUBLE); P(secs, INVALID_DOUBLE); + P(energy, INVALID_DOUBLE); + P(power, INVALID_DOUBLE); + P(temp, INVALID_DOUBLE); + P(clock, INVALID_DOUBLE); #endif From 3bfcc688845c29541d97b02fad90624d11b91175 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Tue, 1 Oct 2024 15:16:27 -0700 Subject: [PATCH 090/103] Wilson ctest invert_test now uses multi-RHS --- tests/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 823633a408..efe0886030 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -1004,7 +1004,7 @@ endif() if(QUDA_DIRAC_WILSON) add_test(NAME invert_test_wilson COMMAND ${QUDA_CTEST_LAUNCH} $ ${MPIEXEC_POSTFLAGS} - --dslash-type wilson --ngcrkrylov 8 + --dslash-type wilson --ngcrkrylov 8 --nsrc 4 --nsrc-tile 4 --dim 2 4 6 8 --niter 1000 --enable-testing true --gtest_output=xml:invert_test_wilson.xml) From d3217adf1509e9c78dedb4420e7a40c09259416c Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Tue, 1 Oct 2024 16:05:22 -0700 Subject: [PATCH 091/103] Fix memory freeing with chrono predictor --- lib/interface_quda.cpp | 17 ++++------------- lib/solve.cpp | 11 +++++++++++ 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/lib/interface_quda.cpp 
b/lib/interface_quda.cpp index 6939e8be34..9e1fc8b2a0 100644 --- a/lib/interface_quda.cpp +++ b/lib/interface_quda.cpp @@ -115,11 +115,6 @@ namespace quda } -// vector of spinors used for forecasting solutions in HMC -#define QUDA_MAX_CHRONO 12 -// each entry is one p -std::vector> chronoResident(QUDA_MAX_CHRONO); - // Mapped memory buffer used to hold unitarization failures static int *num_failures_h = nullptr; static int *num_failures_d = nullptr; @@ -303,6 +298,8 @@ static void profilerStop(const char *f) { namespace quda { void printLaunchTimer(); + void flushChrono(int i = -1); + void massRescale(cvector_ref &b, QudaInvertParam ¶m, bool for_multishift); void distanceReweight(cvector_ref &b, QudaInvertParam ¶m, bool inverse); @@ -1334,13 +1331,7 @@ void freeCloverQuda(void) cloverPrecise = nullptr; } -void flushChronoQuda(int i) -{ - if (i >= QUDA_MAX_CHRONO) - errorQuda("Requested chrono index %d is outside of max %d\n", i, QUDA_MAX_CHRONO); - - chronoResident[i].clear(); -} +void flushChronoQuda(int i) { flushChrono(i); } void endQuda(void) { @@ -1352,7 +1343,7 @@ void endQuda(void) freeGaugeQuda(); freeCloverQuda(); - for (int i = 0; i < QUDA_MAX_CHRONO; i++) flushChronoQuda(i); + flushChrono(); solutionResident.clear(); momResident = GaugeField(); diff --git a/lib/solve.cpp b/lib/solve.cpp index 938e3cfecd..f06d5b6bd6 100644 --- a/lib/solve.cpp +++ b/lib/solve.cpp @@ -8,6 +8,17 @@ namespace quda // each entry is one p std::vector> chronoResident(QUDA_MAX_CHRONO); + void flushChrono(int i) + { + if (i >= QUDA_MAX_CHRONO) + errorQuda("Requested chrono index %d is outside of max %d", i, QUDA_MAX_CHRONO); + + if (i >= 0) + chronoResident[i].clear(); + else + for (auto i = 0; i < QUDA_MAX_CHRONO; i++) chronoResident[i].clear(); + } + void massRescale(cvector_ref &b, QudaInvertParam ¶m, bool for_multishift) { double kappa5 = (0.5 / (5.0 + param.m5)); From 1e2822165b7331be890180e0f57ca4eddc7305d2 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Tue, 1 Oct 2024 
22:27:00 -0700 Subject: [PATCH 092/103] Fix CG3 for MRHS --- lib/inv_cg3_quda.cpp | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/lib/inv_cg3_quda.cpp b/lib/inv_cg3_quda.cpp index e19fd508bc..54de943b22 100644 --- a/lib/inv_cg3_quda.cpp +++ b/lib/inv_cg3_quda.cpp @@ -68,6 +68,7 @@ namespace quda { create(x, b); auto stop = stopping(param.tol, b2, param.residual_type); // stopping condition of solver + auto stop_hq = vector(b.size(), param.tol_hq); const bool use_heavy_quark_res = (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) ? true : false; @@ -117,7 +118,7 @@ namespace quda { } getProfile().TPSTOP(QUDA_PROFILE_PREAMBLE); - if (convergence(r2, heavy_quark_res, stop, param.tol_hq)) return; + if (convergence(r2, heavy_quark_res, stop, stop_hq)) return; getProfile().TPSTART(QUDA_PROFILE_COMPUTE); auto r2_old = r2; @@ -133,7 +134,7 @@ namespace quda { vector rho(b.size(), 1.0); vector gamma(b.size(), 1.0); - while ( !convergence(r2, heavy_quark_res, stop, param.tol_hq) && k < param.maxiter) { + while ( !convergence(r2, heavy_quark_res, stop, stop_hq) && k < param.maxiter) { matSloppy(ArS, rS); auto gamma_old = gamma; @@ -173,10 +174,10 @@ namespace quda { update = ( update || (rNorm < delta*maxrr && r0Norm <= maxrr)); // condition for r // force a reliable update if we are within target tolerance (only if doing reliable updates) - if (convergence(r2, heavy_quark_res, stop, param.tol_hq) && param.delta >= param.tol) update = true; + if (convergence(r2, heavy_quark_res, stop, stop_hq) && param.delta >= param.tol) update = true; // For heavy-quark inversion force a reliable update if we continue after - if ( use_heavy_quark_res and L2breakdown and convergenceHQ(heavy_quark_res, param.tol_hq) and param.delta >= param.tol ) { + if ( use_heavy_quark_res and L2breakdown and convergenceHQ(heavy_quark_res, stop_hq) and param.delta >= param.tol ) { update = true; } @@ -197,7 +198,7 @@ namespace quda { maxrr = rNorm; maxrx = rNorm; // we 
update sloppy and old fields - if (!convergence(r2, heavy_quark_res, stop, param.tol_hq)) { + if (!convergence(r2, heavy_quark_res, stop, stop_hq)) { blas::copy(rS, r); blas::axpy(-1., xS, xS_old); // we preserve the orthogonality between the previous residual and the new @@ -246,12 +247,12 @@ namespace quda { } } } else { - if (convergence(r2, heavy_quark_res, stop, param.tol_hq)) { + if (convergence(r2, heavy_quark_res, stop, stop_hq)) { mat(r, x); r2 = blas::xmyNorm(b, r); r0Norm = sqrt(r2[0]); // we update sloppy and old fields - if (!convergence(r2, heavy_quark_res, stop, param.tol_hq)) { + if (!convergence(r2, heavy_quark_res, stop, stop_hq)) { // we preserve the orthogonality between the previous residual and the new auto rr_old = blas::cDotProduct(rS, rS_old); for (auto i = 0u; i < r2.size(); i++) rr_old[i] /= r2[i]; @@ -295,7 +296,7 @@ namespace quda { } } - PrintSummary("CG3", k, r2, b2, stop, param.tol_hq); + PrintSummary("CG3", k, r2, b2, stop, stop_hq); getProfile().TPSTOP(QUDA_PROFILE_EPILOGUE); } From 9936354da5066bdb5ab07582e800810e140bc5f5 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Wed, 2 Oct 2024 10:59:38 -0700 Subject: [PATCH 093/103] Fix clover force test --- include/dslash.h | 8 ++++---- include/tunable_nd.h | 23 ++++++++++++++++++++--- 2 files changed, 24 insertions(+), 7 deletions(-) diff --git a/include/dslash.h b/include/dslash.h index e51ea64b3e..4b1aada52e 100644 --- a/include/dslash.h +++ b/include/dslash.h @@ -210,8 +210,8 @@ namespace quda /* for nvshmem uber kernels the current synchronization requires us to keep the y and z dimension local to the * block. This can be removed when we introduce a finer grained synchronization which takes into account the y and * z components explicitly */ - step_y = arg.shmem & 64 ? vector_length_y : 1; - step_z = arg.shmem & 64 ? vector_length_z : 1; + step_y = arg.shmem & 64 ? vector_length_y : step_y_bkup; + step_z = arg.shmem & 64 ? 
vector_length_z : step_z_bkup; TunableKernel3D::initTuneParam(param); if (arg.pack_threads && (arg.kernel_type == INTERIOR_KERNEL || arg.kernel_type == UBER_KERNEL)) @@ -224,8 +224,8 @@ namespace quda /* for nvshmem uber kernels the current synchronization requires use to keep the y and z dimension local to the * block. This can be removed when we introduce a finer grained synchronization which takes into account the y and * z components explicitly. */ - step_y = arg.shmem & 64 ? vector_length_y : 1; - step_z = arg.shmem & 64 ? vector_length_z : 1; + step_y = arg.shmem & 64 ? vector_length_y : step_y_bkup; + step_z = arg.shmem & 64 ? vector_length_z : step_z_bkup; TunableKernel3D::defaultTuneParam(param); if (arg.pack_threads && (arg.kernel_type == INTERIOR_KERNEL || arg.kernel_type == UBER_KERNEL)) diff --git a/include/tunable_nd.h b/include/tunable_nd.h index b942bc041d..f7076f9dff 100644 --- a/include/tunable_nd.h +++ b/include/tunable_nd.h @@ -169,6 +169,7 @@ namespace quda protected: mutable unsigned int vector_length_y; mutable unsigned int step_y; + mutable unsigned int step_y_bkup; bool tune_block_x; /** @@ -231,7 +232,11 @@ namespace quda */ TunableKernel2D_base(const LatticeField &field, unsigned int vector_length_y, QudaFieldLocation location = QUDA_INVALID_FIELD_LOCATION) : - TunableKernel1D_base(field, location), vector_length_y(vector_length_y), step_y(1), tune_block_x(true) + TunableKernel1D_base(field, location), + vector_length_y(vector_length_y), + step_y(1), + step_y_bkup(1), + tune_block_x(true) { } @@ -242,7 +247,11 @@ namespace quda @param[in] location Location where the calculation will take place */ TunableKernel2D_base(size_t n_items, unsigned int vector_length_y, QudaFieldLocation location) : - TunableKernel1D_base(n_items, location), vector_length_y(vector_length_y), step_y(1), tune_block_x(true) + TunableKernel1D_base(n_items, location), + vector_length_y(vector_length_y), + step_y(1), + step_y_bkup(1), + tune_block_x(true) { } @@ -317,7 
+326,11 @@ namespace quda @brief Resize the autotuning step size in the y dimension @brief[in] y New step size */ - void resizeStep(int y) const { step_y = y; } + void resizeStep(int y) const + { + step_y = y; + step_y_bkup = step_y; + } }; /** @@ -410,6 +423,7 @@ namespace quda using TunableKernel2D_base::vector_length_y; mutable unsigned vector_length_z; mutable unsigned step_z; + mutable unsigned step_z_bkup; bool tune_block_y; /** @@ -478,6 +492,7 @@ namespace quda TunableKernel2D_base(field, vector_length_y, location), vector_length_z(vector_length_z), step_z(1), + step_z_bkup(1), tune_block_y(true) { } @@ -494,6 +509,7 @@ namespace quda TunableKernel2D_base(n_items, vector_length_y, location), vector_length_z(vector_length_z), step_z(1), + step_z_bkup(1), tune_block_y(true) { } @@ -581,6 +597,7 @@ namespace quda void resizeStep(int y, int z) const { step_z = z; + step_z_bkup = step_z; TunableKernel2D_base::resizeStep(y); } }; From a0184d6f6ff8280805ffc8019c305349982ac80b Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Wed, 2 Oct 2024 17:12:40 -0700 Subject: [PATCH 094/103] Heterogeneous reductions now break up the device-local partial read and writes into atoms. 
This fixes an issue with the MR-based Schwarz solver when using heterogenous reductions with batched solves --- include/kernels/blas_core.cuh | 8 +++---- include/targets/cuda/reduce_helper.h | 36 +++++++++++++++++++++++----- include/tunable_reduction.h | 4 ++-- lib/inv_mr_quda.cpp | 4 ++-- 4 files changed, 38 insertions(+), 14 deletions(-) diff --git a/include/kernels/blas_core.cuh b/include/kernels/blas_core.cuh index 969a31fcae..ad97210e2e 100644 --- a/include/kernels/blas_core.cuh +++ b/include/kernels/blas_core.cuh @@ -418,17 +418,17 @@ namespace quda static constexpr memory_access<1, 1, 1> read{ }; static constexpr memory_access<1, 1> write{ }; complex a[MAX_MULTI_RHS] = {}; - double3 *Ar3; + double4 *Ar4; caxpyxmazMR_(cvector &a, cvector &, cvector &) : - Ar3(static_cast(reducer::get_device_buffer())) + Ar4(static_cast(reducer::get_device_buffer())) { for (auto i = 0u; i < a.size(); i++) this->a[i] = a[i]; } template __device__ __host__ void operator()(T &x, T &y, T &z, T &, T &, int j) const { - auto ar3 = Ar3[j]; - auto aj = a[j].real() * complex((real)ar3.x, (real)ar3.y) * ((real)1.0 / (real)ar3.z); + auto ar4 = Ar4[j]; + auto aj = a[j].real() * complex((real)ar4.x, (real)ar4.y) * ((real)1.0 / (real)ar4.z); #pragma unroll for (int i = 0; i < x.size(); i++) { diff --git a/include/targets/cuda/reduce_helper.h b/include/targets/cuda/reduce_helper.h index 73fc0bfbc0..b604f14e52 100644 --- a/include/targets/cuda/reduce_helper.h +++ b/include/targets/cuda/reduce_helper.h @@ -55,7 +55,8 @@ namespace quda bool reset = false; /** reset the counter post completion (required for multiple calls with the same arg instance */ using system_atomic_t = typename atomic_type::type; /** heterogeneous atomics must use lock-free atomics -> operate on scalars */ static constexpr int n_item = sizeof(T) / sizeof(system_atomic_t); /** number of words per reduction variable */ - cuda::atomic *partial; /** device atomic buffer */ + // FIXME on Hopper this could be a 128-bit type + 
cuda::atomic *partial; /** device atomic buffer */ cuda::atomic *result_d; /** device-mapped host atomic buffer */ cuda::atomic *result_h; /** host atomic buffer */ count_t *count; /** count array that is used to track the number of completed thread blocks at a given batch index */ @@ -78,7 +79,7 @@ namespace quda reset(reset), consumed(false) { - reducer::init(n_reduce, sizeof(*partial)); + reducer::init(n_reduce, n_item * sizeof(*partial)); // these buffers may be allocated in init, so we can't set the local copies until now partial = static_cast(reducer::get_device_buffer()); result_d = static_cast(reducer::get_mapped_buffer()); @@ -209,7 +210,15 @@ namespace quda if (arg.get_output_async_buffer()) { arg.get_output_async_buffer()[idx] = sum; } else { // write to device memory - arg.partial[idx].store(sum, cuda::std::memory_order_relaxed); + // write out the final reduced value + if (tid == 0) { + atomic_t sum_tmp[n]; + memcpy(sum_tmp, &sum, sizeof(sum)); +#pragma unroll + for (unsigned int i = 0; i < n; i++) { + arg.partial[n * idx + i].store(sum_tmp[i], cuda::std::memory_order_relaxed); + } + } } } } @@ -243,6 +252,8 @@ namespace quda template __device__ inline void reduce(Arg &arg, const Reducer &r, const T &in, const int idx) { + using atomic_t = typename atomic_type::type; + constexpr size_t n = sizeof(T) / sizeof(atomic_t); constexpr auto n_batch_block = std::min(Arg::max_n_batch_block, device::max_block_size()); using BlockReduce = BlockReduce; @@ -255,8 +266,14 @@ namespace quda if (target::thread_idx().x == 0 && target::thread_idx().y == 0) { // need to call placement new constructor since partial is not necessarily constructed - new (arg.partial + idx * target::grid_dim().x + target::block_idx().x) - cuda::atomic {aggregate}; + atomic_t aggregate_tmp[n]; + memcpy(aggregate_tmp, &aggregate, sizeof(aggregate)); + +#pragma unroll + for (int k = 0; k < n; k++) { + new (arg.partial + (idx * target::grid_dim().x + target::block_idx().x) * n + k) + 
cuda::atomic {aggregate_tmp[k]}; + } // increment global block counter for this reduction auto value = arg.count[idx].fetch_add(1, cuda::std::memory_order_release); @@ -272,7 +289,14 @@ namespace quda auto i = target::thread_idx().y * target::block_dim().x + target::thread_idx().x; T sum = r.init(); while (i < target::grid_dim().x) { - sum = r(sum, arg.partial[idx * target::grid_dim().x + i].load(cuda::std::memory_order_relaxed)); + atomic_t partial_tmp[n]; + T partial; +#pragma unroll + for (int k = 0; k < n; k++) { + partial_tmp[k] = arg.partial[(idx * target::grid_dim().x + i) * n + k].load(cuda::std::memory_order_relaxed); + } + memcpy(&partial, partial_tmp, sizeof(partial)); + sum = r(sum, partial); i += target::block_size<2>(); } diff --git a/include/tunable_reduction.h b/include/tunable_reduction.h index c575cb8135..0ec7d4b6c2 100644 --- a/include/tunable_reduction.h +++ b/include/tunable_reduction.h @@ -144,7 +144,7 @@ namespace quda n_items(field.Volume()), block_size_y(block_size_y) { - if (commAsyncReduction()) strcat(aux, "async,"); + if (commAsyncReduction()) strcat(aux, ",async"); } /** @@ -155,7 +155,7 @@ namespace quda TunableReduction2D(size_t n_items, QudaFieldLocation location) : TunableKernel(n_items, location), n_items(n_items), block_size_y(1) { - if (commAsyncReduction()) strcat(aux, "async,"); + if (commAsyncReduction()) strcat(aux, ",async"); } /** diff --git a/lib/inv_mr_quda.cpp b/lib/inv_mr_quda.cpp index d8565286fa..c539782ad7 100644 --- a/lib/inv_mr_quda.cpp +++ b/lib/inv_mr_quda.cpp @@ -64,7 +64,7 @@ namespace quda if (!param.is_preconditioner) getProfile().TPSTART(QUDA_PROFILE_COMPUTE); vector b2 = blas::norm2(b); // Save norm of b - vector r2; // if zero source then we will exit immediately doing no work + vector r2; if (param.use_init_guess == QUDA_USE_INIT_GUESS_YES) { mat(r, x); @@ -125,7 +125,7 @@ namespace quda } else { // doing local reductions so can make it asynchronous commAsyncReductionSet(true); - 
blas::cDotProductNormA(Ar, r_sloppy); + blas::cDotProductNormAB(Ar, r_sloppy); // omega*alpha is done in the kernel blas::caxpyXmazMR(param.omega, r_sloppy, x_sloppy, Ar); From e51c59ca1c3fbb264e8c7e639079c56ff5ac663b Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Thu, 3 Oct 2024 16:58:04 -0700 Subject: [PATCH 095/103] ctest should use mrhs for asqtad solver test --- tests/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index efe0886030..b295e00d9c 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -1204,7 +1204,7 @@ foreach(prec IN LISTS TEST_PRECS) COMMAND ${QUDA_CTEST_LAUNCH} $ ${MPIEXEC_POSTFLAGS} --dslash-type asqtad --ngcrkrylov 8 --compute-fat-long true --dim 6 6 6 8 --prec ${prec} --tol ${tol} --tolhq ${tol} --niter 1000 - --enable-testing true + --enable-testing true --nsrc 4 --nsrc-tile 4 --gtest_output=xml:invert_test_asqtad_${prec}.xml) if(DEFINED ENV{QUDA_ENABLE_TUNING}) From 70a3b7550f56f079f709654d04098cf9c5e32a1a Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Fri, 4 Oct 2024 10:21:46 -0700 Subject: [PATCH 096/103] Fix typo --- lib/inv_bicgstab_quda.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/inv_bicgstab_quda.cpp b/lib/inv_bicgstab_quda.cpp index be5c1583ff..2cffc2aeed 100644 --- a/lib/inv_bicgstab_quda.cpp +++ b/lib/inv_bicgstab_quda.cpp @@ -368,7 +368,7 @@ namespace quda { logQuda(QUDA_VERBOSE, "BiCGstab: Reliable updates = %d\n", rUpdate); - if (!param.is_preconditioner) { // do not do the below if we this is an inner solver + if (!param.is_preconditioner) { // do not do the below if this is an inner solver // r2 was freshly computed auto hq = use_heavy_quark_res ? 
blas::HeavyQuarkResidualNorm(x, r) : vector(b.size(), {}); for (auto i = 0u; i < b.size(); i++) { From 9d4abe93b57bb408e3c174a5466e04c82d11c93b Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Mon, 7 Oct 2024 04:23:21 -0700 Subject: [PATCH 097/103] Fix for QudaMultigridParam::dslash_use_mma so that it respects the correct level. Default test batch size for eigenvalue computation is now 16 (to match the default mma nvec instantiation --- lib/block_transpose.in.cu | 2 +- lib/multigrid.cpp | 3 ++- tests/utils/command_line_params.cpp | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/lib/block_transpose.in.cu b/lib/block_transpose.in.cu index 505613fdb7..9d34fa0f40 100644 --- a/lib/block_transpose.in.cu +++ b/lib/block_transpose.in.cu @@ -113,7 +113,7 @@ namespace quda if constexpr (sizeof...(N) > 0) { launch_span_nVec(V, B, nVecs); } else { - errorQuda("nVec = %d not instantiated\n", V.Nvec()); + errorQuda("nVec = %d not instantiated", V.Nvec()); } } } diff --git a/lib/multigrid.cpp b/lib/multigrid.cpp index e36900f155..5241fdeabc 100644 --- a/lib/multigrid.cpp +++ b/lib/multigrid.cpp @@ -408,7 +408,8 @@ namespace quda diracParam.type = QUDA_COARSE_DIRAC; diracParam.halo_precision = param.mg_global.precision_null[param.level]; diracParam.setup_use_mma = param.mg_global.setup_use_mma[param.level]; - diracParam.dslash_use_mma = param.mg_global.dslash_use_mma[param.level]; + // level + 1 since this is for the coarse grid + diracParam.dslash_use_mma = param.mg_global.dslash_use_mma[param.level + 1]; diracParam.allow_truncation = (param.mg_global.allow_truncation == QUDA_BOOLEAN_TRUE) ? true : false; diracCoarseResidual = new DiracCoarse(diracParam, param.setup_location == QUDA_CUDA_FIELD_LOCATION ? 
true : false, diff --git a/tests/utils/command_line_params.cpp b/tests/utils/command_line_params.cpp index 3c172ebee0..4bf7c96c44 100644 --- a/tests/utils/command_line_params.cpp +++ b/tests/utils/command_line_params.cpp @@ -213,7 +213,7 @@ QudaMemoryType mem_type_ritz = QUDA_MEMORY_DEVICE; // Parameters for the stand alone eigensolver int eig_ortho_block_size = 0; -int eig_evals_batch_size = 4; +int eig_evals_batch_size = 16; int eig_block_size = 4; int eig_n_ev = 16; int eig_n_kr = 32; From a9ef50be4bfb365048d55f110a92929840630322 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Mon, 7 Oct 2024 04:37:10 -0700 Subject: [PATCH 098/103] Apply clang format --- include/color_spinor_field.h | 4 ++-- include/invert_quda.h | 2 +- include/reference_wrapper_helper.h | 1 - .../cuda/mma_tensor_op/smma_m16n8_sm80.cuh | 3 ++- include/targets/cuda/reduce_helper.h | 2 +- lib/check_params.h | 4 ++-- lib/dslash_clover_helper.cu | 2 +- lib/dslash_pack2.cu | 2 +- lib/interface_quda.cpp | 4 ++-- lib/inv_cg3_quda.cpp | 4 ++-- lib/inv_gcr_quda.cpp | 8 +++---- lib/inv_mr_quda.cpp | 3 ++- lib/inv_sd_quda.cpp | 3 ++- lib/solve.cpp | 3 +-- lib/solver.cpp | 23 +++++++++++-------- lib/solver.hpp | 4 ++-- tests/asan.h | 17 +++++++------- tests/invert_test.cpp | 18 ++++++--------- tests/staggered_invert_test.cpp | 20 +++++++--------- tests/utils/command_line_params.cpp | 8 ++++--- 20 files changed, 67 insertions(+), 68 deletions(-) diff --git a/include/color_spinor_field.h b/include/color_spinor_field.h index 3023a4e248..18de570604 100644 --- a/include/color_spinor_field.h +++ b/include/color_spinor_field.h @@ -950,7 +950,7 @@ namespace quda functionality is useful for the case where we have multiple temporaries in different precisions, but do not need them simultaneously. Use this functionality with caution. 
- @param[out] alias The vector of aliased fields + @param[out] alias The vector of aliased fields @param[in] v The vector of fields to alias @param[in] param Parameters for the alias field */ @@ -966,7 +966,7 @@ namespace quda variant is used with std::vector as opposed to vector_ref, and allows for correct resizing. Use this functionality with caution. - @param[out] alias The vector of aliased fields + @param[out] alias The vector of aliased fields @param[in] v The vector of fields to alias @param[in] param Parameters for the alias field */ diff --git a/include/invert_quda.h b/include/invert_quda.h index c235a98192..d9d5279066 100644 --- a/include/invert_quda.h +++ b/include/invert_quda.h @@ -1115,7 +1115,7 @@ namespace quda { std::vector r; //! residual vector std::vector r_sloppy; //! sloppy residual vector - int k_break = 0; //! track when the solver converged + int k_break = 0; //! track when the solver converged std::vector> p; // GCR direction vectors std::vector> Ap; // mat * direction vectors diff --git a/include/reference_wrapper_helper.h b/include/reference_wrapper_helper.h index 1dbf9f6df8..ed27d17106 100644 --- a/include/reference_wrapper_helper.h +++ b/include/reference_wrapper_helper.h @@ -570,7 +570,6 @@ namespace quda for (auto &v : multiplied) v *= u; return multiplied; } - }; template vector operator*(const T &a, const vector &b) { return b * a; } diff --git a/include/targets/cuda/mma_tensor_op/smma_m16n8_sm80.cuh b/include/targets/cuda/mma_tensor_op/smma_m16n8_sm80.cuh index b08d5c0b7e..2c801a4f46 100644 --- a/include/targets/cuda/mma_tensor_op/smma_m16n8_sm80.cuh +++ b/include/targets/cuda/mma_tensor_op/smma_m16n8_sm80.cuh @@ -319,7 +319,8 @@ namespace quda for (int i = 0; i < warp_m * warp_n * thread_count; i++) { reg[i] *= alpha; } } - __device__ inline void axpy(float alpha, OperandC x) { + __device__ inline void axpy(float alpha, OperandC x) + { #pragma unroll for (int i = 0; i < warp_m * warp_n * thread_count; i++) { reg[i] += alpha * 
x.reg[i]; } } diff --git a/include/targets/cuda/reduce_helper.h b/include/targets/cuda/reduce_helper.h index b604f14e52..da3e178bcf 100644 --- a/include/targets/cuda/reduce_helper.h +++ b/include/targets/cuda/reduce_helper.h @@ -56,7 +56,7 @@ namespace quda using system_atomic_t = typename atomic_type::type; /** heterogeneous atomics must use lock-free atomics -> operate on scalars */ static constexpr int n_item = sizeof(T) / sizeof(system_atomic_t); /** number of words per reduction variable */ // FIXME on Hopper this could be a 128-bit type - cuda::atomic *partial; /** device atomic buffer */ + cuda::atomic *partial; /** device atomic buffer */ cuda::atomic *result_d; /** device-mapped host atomic buffer */ cuda::atomic *result_h; /** host atomic buffer */ count_t *count; /** count array that is used to track the number of completed thread blocks at a given batch index */ diff --git a/lib/check_params.h b/lib/check_params.h index 266f12d727..51ed9acc38 100644 --- a/lib/check_params.h +++ b/lib/check_params.h @@ -472,8 +472,8 @@ void printQudaInvertParam(QudaInvertParam *param) { if (param->num_src_per_sub_partition < 1) errorQuda("Invalid num_src_per_subpartition = %d", param->num_src_per_sub_partition); if (param->num_src % param->num_src_per_sub_partition != 0) - errorQuda("num_src %d not compatible with num_src_per_sub_partition %d", - param->num_src, param->num_src_per_sub_partition); + errorQuda("num_src %d not compatible with num_src_per_sub_partition %d", param->num_src, + param->num_src_per_sub_partition); } #endif #endif diff --git a/lib/dslash_clover_helper.cu b/lib/dslash_clover_helper.cu index d8923a7044..2fc342bfb4 100644 --- a/lib/dslash_clover_helper.cu +++ b/lib/dslash_clover_helper.cu @@ -125,7 +125,7 @@ namespace quda { long long bytes() const { long long rtn = out.Bytes() + in.Bytes() + in.size() * clover.Bytes() / (3 - in.SiteSubset()); if (twist == QUDA_TWIST_GAMMA5_INVERSE && !clover::dynamic_inverse()) - rtn += in.size() * clover.Bytes() / 
(3 - in.SiteSubset()); + rtn += in.size() * clover.Bytes() / (3 - in.SiteSubset()); return rtn; } }; diff --git a/lib/dslash_pack2.cu b/lib/dslash_pack2.cu index 46580ad6f1..043ea65abc 100644 --- a/lib/dslash_pack2.cu +++ b/lib/dslash_pack2.cu @@ -341,7 +341,7 @@ public: #endif } else { - errorQuda("Unsupported nSpin = %d", in.Nspin()); + errorQuda("Unsupported nSpin = %d", in.Nspin()); } } diff --git a/lib/interface_quda.cpp b/lib/interface_quda.cpp index 9e1fc8b2a0..5575d36a54 100644 --- a/lib/interface_quda.cpp +++ b/lib/interface_quda.cpp @@ -4592,7 +4592,7 @@ void computeCloverForceQuda(void *h_mom, double dt, void **h_x, void **, double // Make sure extendedGaugeResident has the correct R if (extendedGaugeResident) delete extendedGaugeResident; lat_dim_t R; - for (int d=0; d<4; d++) R[d] = (d==0 ? 2 : 1) * (redundant_comms || commDimPartitioned(d)); + for (int d = 0; d < 4; d++) R[d] = (d == 0 ? 2 : 1) * (redundant_comms || commDimPartitioned(d)); extendedGaugeResident = createExtendedGauge(*gaugePrecise, R, getProfile()); GaugeField &gaugeEx = *extendedGaugeResident; @@ -4667,7 +4667,7 @@ void computeTMCloverForceQuda(void *h_mom, void **h_x, void **h_x0, double *coef // Make sure extendedGaugeResident has the correct R if (extendedGaugeResident) delete extendedGaugeResident; lat_dim_t R; - for (int d=0; d<4; d++) R[d] = (d==0 ? 2 : 1) * (redundant_comms || commDimPartitioned(d)); + for (int d = 0; d < 4; d++) R[d] = (d == 0 ? 
2 : 1) * (redundant_comms || commDimPartitioned(d)); extendedGaugeResident = createExtendedGauge(*gaugePrecise, R, profileTMCloverForce); GaugeField &gaugeEx = *extendedGaugeResident; diff --git a/lib/inv_cg3_quda.cpp b/lib/inv_cg3_quda.cpp index 54de943b22..75d6cc93d9 100644 --- a/lib/inv_cg3_quda.cpp +++ b/lib/inv_cg3_quda.cpp @@ -134,7 +134,7 @@ namespace quda { vector rho(b.size(), 1.0); vector gamma(b.size(), 1.0); - while ( !convergence(r2, heavy_quark_res, stop, stop_hq) && k < param.maxiter) { + while (!convergence(r2, heavy_quark_res, stop, stop_hq) && k < param.maxiter) { matSloppy(ArS, rS); auto gamma_old = gamma; @@ -177,7 +177,7 @@ namespace quda { if (convergence(r2, heavy_quark_res, stop, stop_hq) && param.delta >= param.tol) update = true; // For heavy-quark inversion force a reliable update if we continue after - if ( use_heavy_quark_res and L2breakdown and convergenceHQ(heavy_quark_res, stop_hq) and param.delta >= param.tol ) { + if (use_heavy_quark_res and L2breakdown and convergenceHQ(heavy_quark_res, stop_hq) and param.delta >= param.tol) { update = true; } diff --git a/lib/inv_gcr_quda.cpp b/lib/inv_gcr_quda.cpp index d0d8031413..8e9c1df32e 100644 --- a/lib/inv_gcr_quda.cpp +++ b/lib/inv_gcr_quda.cpp @@ -66,7 +66,7 @@ namespace quda { for (int i=0; i0; r--) { @@ -302,7 +302,7 @@ namespace quda { k_break = 0; PrintStats("GCR", total_iter+k, r2, b2, heavy_quark_res); - while ( !convergence(r2, heavy_quark_res, stop, stop_hq) && total_iter < param.maxiter) { + while (!convergence(r2, heavy_quark_res, stop, stop_hq) && total_iter < param.maxiter) { if (K) { pushVerbosity(param.verbosity_precondition); @@ -389,7 +389,7 @@ namespace quda { k_break = k; k = 0; - if ( !convergence(r2, heavy_quark_res, stop, stop_hq) ) { + if (!convergence(r2, heavy_quark_res, stop, stop_hq)) { restart++; // restarting if residual is still too great PrintStats("GCR (restart)", restart, r2, b2, heavy_quark_res); @@ -398,7 +398,7 @@ namespace quda { r2_old = r2; // 
prevent ending the Krylov space prematurely if other convergence criteria not met - if (r2 < stop) l2_converge = true; + if (r2 < stop) l2_converge = true; } r2_old = r2; diff --git a/lib/inv_mr_quda.cpp b/lib/inv_mr_quda.cpp index c539782ad7..1696cd4d42 100644 --- a/lib/inv_mr_quda.cpp +++ b/lib/inv_mr_quda.cpp @@ -69,7 +69,8 @@ namespace quda if (param.use_init_guess == QUDA_USE_INIT_GUESS_YES) { mat(r, x); r2 = blas::xmyNorm(b, r); // r = b - Ax0 - for (auto i = 0u; i < b.size(); i++) if (b2[i] == 0) b2[i] = r2[i]; + for (auto i = 0u; i < b.size(); i++) + if (b2[i] == 0) b2[i] = r2[i]; } else { r2 = b2; blas::copy(r, b); diff --git a/lib/inv_sd_quda.cpp b/lib/inv_sd_quda.cpp index 929064ee28..ba74e5059c 100644 --- a/lib/inv_sd_quda.cpp +++ b/lib/inv_sd_quda.cpp @@ -44,7 +44,8 @@ namespace quda { // Compute the true residual mat(r, x); r2 = blas::xmyNorm(b, r); - for (auto i = 0u; i < b.size(); i++) if (b2[i] == 0) b2[i] = r2[i]; + for (auto i = 0u; i < b.size(); i++) + if (b2[i] == 0) b2[i] = r2[i]; } else { blas::zero(x); blas::copy(r, b); diff --git a/lib/solve.cpp b/lib/solve.cpp index f06d5b6bd6..d4eb2bed8d 100644 --- a/lib/solve.cpp +++ b/lib/solve.cpp @@ -10,8 +10,7 @@ namespace quda void flushChrono(int i) { - if (i >= QUDA_MAX_CHRONO) - errorQuda("Requested chrono index %d is outside of max %d", i, QUDA_MAX_CHRONO); + if (i >= QUDA_MAX_CHRONO) errorQuda("Requested chrono index %d is outside of max %d", i, QUDA_MAX_CHRONO); if (i >= 0) chronoResident[i].clear(); diff --git a/lib/solver.cpp b/lib/solver.cpp index 877c7c1ac8..2f9be27701 100644 --- a/lib/solver.cpp +++ b/lib/solver.cpp @@ -413,7 +413,8 @@ namespace quda { // check the L2 relative residual norm if necessary if ((param.residual_type & QUDA_L2_RELATIVE_RESIDUAL) || (param.residual_type & QUDA_L2_ABSOLUTE_RESIDUAL)) { - if (std::isnan(r2[i]) || std::isinf(r2[i])) errorQuda("Solver appears to have diverged with residual %9.6e", r2[i]); + if (std::isnan(r2[i]) || std::isinf(r2[i])) + 
errorQuda("Solver appears to have diverged with residual %9.6e", r2[i]); if (r2[i] > r2_tol[i]) return false; } } @@ -444,7 +445,8 @@ namespace quda { for (auto i = 0u; i < r2.size(); i++) { // check the L2 relative residual norm if necessary if ((param.residual_type & QUDA_L2_RELATIVE_RESIDUAL) || (param.residual_type & QUDA_L2_ABSOLUTE_RESIDUAL)) { - if (std::isnan(r2[i]) || std::isinf(r2[i])) errorQuda("Solver appears to have diverged with residual %9.6e", r2[i]); + if (std::isnan(r2[i]) || std::isinf(r2[i])) + errorQuda("Solver appears to have diverged with residual %9.6e", r2[i]); if (r2[i] > r2_tol[i]) return false; } } @@ -467,10 +469,11 @@ namespace quda { for (auto i = 0u; i < r2.size(); i++) { auto rhs_str = set_rhs_str(i, r2.size()); if (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) { - logQuda(QUDA_VERBOSE, "%s: %5d iterations, %s = %9.6e, |r|/|b| = %9.6e, heavy-quark residual = %9.6e\n", name, - k, rhs_str.c_str(), r2[i], sqrt(r2[i] / b2[i]), hq2[i]); + logQuda(QUDA_VERBOSE, "%s: %5d iterations, %s = %9.6e, |r|/|b| = %9.6e, heavy-quark residual = %9.6e\n", + name, k, rhs_str.c_str(), r2[i], sqrt(r2[i] / b2[i]), hq2[i]); } else { - logQuda(QUDA_VERBOSE, "%s: %5d iterations, %s = %9.6e, |r|/|b| = %9.6e\n", name, k, rhs_str.c_str(), r2[i], sqrt(r2[i] / b2[i])); + logQuda(QUDA_VERBOSE, "%s: %5d iterations, %s = %9.6e, |r|/|b| = %9.6e\n", name, k, rhs_str.c_str(), r2[i], + sqrt(r2[i] / b2[i])); } if (std::isnan(r2[i]) || std::isinf(r2[i])) errorQuda("Solver appears to have diverged for n = %d", i); @@ -492,7 +495,8 @@ namespace quda { logQuda(QUDA_SUMMARIZE, "%s: Convergence at %d iterations, %sL2 relative residual: iterated = %9.6e, true = %9.6e " "(requested = %9.6e), heavy-quark residual = %9.6e (requested = %9.6e)\n", - name, k, rhs_str.c_str(), sqrt(r2[i] / b2[i]), param.true_res[i], sqrt(r2_tol[i] / b2[i]), param.true_res_hq[i], hq_tol[i]); + name, k, rhs_str.c_str(), sqrt(r2[i] / b2[i]), param.true_res[i], sqrt(r2_tol[i] / b2[i]), + 
param.true_res_hq[i], hq_tol[i]); } else { logQuda(QUDA_SUMMARIZE, "%s: Convergence at %d iterations, %sL2 relative residual: iterated = %9.6e, true = %9.6e " @@ -504,11 +508,12 @@ namespace quda { logQuda(QUDA_SUMMARIZE, "%s: Convergence at %d iterations, %sL2 relative residual: iterated = %9.6e " "(requested = %9.6e), heavy-quark residual = %9.6e (requested = %9.6e)\n", - name, k, rhs_str.c_str(), sqrt(r2[i] / b2[i]), sqrt(r2_tol[i] / b2[i]), param.true_res_hq[i], hq_tol[i]); + name, k, rhs_str.c_str(), sqrt(r2[i] / b2[i]), sqrt(r2_tol[i] / b2[i]), param.true_res_hq[i], + hq_tol[i]); } else { logQuda(QUDA_SUMMARIZE, - "%s: Convergence at %d iterations, %sL2 relative residual: iterated = %9.6e (requested = %9.6e)\n", name, - k, rhs_str.c_str(), sqrt(r2[i] / b2[i]), sqrt(r2_tol[i] / b2[i])); + "%s: Convergence at %d iterations, %sL2 relative residual: iterated = %9.6e (requested = %9.6e)\n", + name, k, rhs_str.c_str(), sqrt(r2[i] / b2[i]), sqrt(r2_tol[i] / b2[i])); } } } diff --git a/lib/solver.hpp b/lib/solver.hpp index 4bf6ffb618..5c308e461c 100644 --- a/lib/solver.hpp +++ b/lib/solver.hpp @@ -61,8 +61,8 @@ namespace quda */ template void Solver::computeCAKrylovSpace(const DiracMatrix &diracm, std::vector> &Ap, - std::vector> &p, int n_krylov, - QudaCABasis basis, double m_map, double b_map, Args &&...args) + std::vector> &p, int n_krylov, QudaCABasis basis, + double m_map, double b_map, Args &&...args) { // in some cases p or Ap may be larger if (static_cast(p.size()) < n_krylov) errorQuda("Invalid p.size() %lu < n_krylov %d", p.size(), n_krylov); diff --git a/tests/asan.h b/tests/asan.h index a56732248f..1c18092210 100644 --- a/tests/asan.h +++ b/tests/asan.h @@ -2,13 +2,12 @@ extern "C" { - /** - @brief Set the default ASAN options. This ensures that QUDA just - works when SANITIZE is enabled without requiring ASAN_OPTIONS to - be set. We default disable leak checking, otherwise this will - cause ctest to fail with MPI library leaks. 
This declaration - cannot be in the test library, and must be in the test executable. - */ - const char *__asan_default_options() { return "detect_leaks=0,protect_shadow_gap=0"; } - +/** + @brief Set the default ASAN options. This ensures that QUDA just + works when SANITIZE is enabled without requiring ASAN_OPTIONS to + be set. We default disable leak checking, otherwise this will + cause ctest to fail with MPI library leaks. This declaration + cannot be in the test library, and must be in the test executable. +*/ +const char *__asan_default_options() { return "detect_leaks=0,protect_shadow_gap=0"; } } diff --git a/tests/invert_test.cpp b/tests/invert_test.cpp index 7005075d9b..2c945cc77c 100644 --- a/tests/invert_test.cpp +++ b/tests/invert_test.cpp @@ -240,9 +240,8 @@ std::vector> solve(test_t param) printfQuda("MG Setup Done: %g secs, %g Gflops\n", mg_param.invert_param->secs, mg_param.invert_param->gflops / mg_param.invert_param->secs); if (mg_param.invert_param->energy > 0) { - printfQuda("Energy = %g J, Mean power = %g W, mean temp = %g C, mean clock = %f\n", - mg_param.invert_param->energy, mg_param.invert_param->power, - mg_param.invert_param->temp, mg_param.invert_param->clock); + printfQuda("Energy = %g J, Mean power = %g W, mean temp = %g C, mean clock = %f\n", mg_param.invert_param->energy, + mg_param.invert_param->power, mg_param.invert_param->temp, mg_param.invert_param->clock); } } @@ -338,8 +337,8 @@ std::vector> solve(test_t param) printfQuda("Done: %i iter / %g secs = %g Gflops\n", inv_param.iter, inv_param.secs, inv_param.gflops / inv_param.secs); if (inv_param.energy > 0) { - printfQuda("Energy = %g J, Mean power = %g W, mean temp = %g C, mean clock = %f\n", - inv_param.energy, inv_param.power, inv_param.temp, inv_param.clock); + printfQuda("Energy = %g J, Mean power = %g W, mean temp = %g C, mean clock = %f\n", inv_param.energy, + inv_param.power, inv_param.temp, inv_param.clock); } } @@ -373,14 +372,11 @@ std::vector> solve(test_t param) 
quda::comm_allreduce_sum(inv_param.gflops); inv_param.gflops /= quda::comm_size() / num_sub_partition; quda::comm_allreduce_max(inv_param.secs); - printfQuda("Done: %d sub-partitions - %i iter / %g secs = %g Gflops, %g secs per source\n", - num_sub_partition, inv_param.iter, - inv_param.secs, inv_param.gflops / inv_param.secs, - inv_param.secs / Nsrc_tile); + printfQuda("Done: %d sub-partitions - %i iter / %g secs = %g Gflops, %g secs per source\n", num_sub_partition, + inv_param.iter, inv_param.secs, inv_param.gflops / inv_param.secs, inv_param.secs / Nsrc_tile); if (inv_param.energy > 0) { printfQuda("Energy = %g J (%g J per source), Mean power = %g W, mean temp = %g C, mean clock = %f\n", - inv_param.energy, inv_param.energy / Nsrc_tile, - inv_param.power, inv_param.temp, inv_param.clock); + inv_param.energy, inv_param.energy / Nsrc_tile, inv_param.power, inv_param.temp, inv_param.clock); } } } diff --git a/tests/staggered_invert_test.cpp b/tests/staggered_invert_test.cpp index 7f1c8dd053..a098324ef6 100644 --- a/tests/staggered_invert_test.cpp +++ b/tests/staggered_invert_test.cpp @@ -286,9 +286,8 @@ std::vector> solve(test_t param) printfQuda("MG Setup Done: %g secs, %g Gflops\n", mg_param.invert_param->secs, mg_param.invert_param->gflops / mg_param.invert_param->secs); if (mg_param.invert_param->energy > 0) { - printfQuda("Energy = %g J, Mean power = %g W, mean temp = %g C, mean clock = %f\n", - mg_param.invert_param->energy, mg_param.invert_param->power, - mg_param.invert_param->temp, mg_param.invert_param->clock); + printfQuda("Energy = %g J, Mean power = %g W, mean temp = %g C, mean clock = %f\n", mg_param.invert_param->energy, + mg_param.invert_param->power, mg_param.invert_param->temp, mg_param.invert_param->clock); } } @@ -395,8 +394,8 @@ std::vector> solve(test_t param) printfQuda("Done: %i iter / %g secs = %g Gflops\n", inv_param.iter, inv_param.secs, inv_param.gflops / inv_param.secs); if (inv_param.energy > 0) { - printfQuda("Energy = %g J, Mean 
power = %g W, mean temp = %g C, mean clock = %f\n\n", - inv_param.energy, inv_param.power, inv_param.temp, inv_param.clock); + printfQuda("Energy = %g J, Mean power = %g W, mean temp = %g C, mean clock = %f\n\n", inv_param.energy, + inv_param.power, inv_param.temp, inv_param.clock); } } } else { @@ -413,8 +412,7 @@ std::vector> solve(test_t param) _hp_b[i] = in[j + i].data(); } - if (inv_deflate) - eig_param.preserve_deflation = j < Nsrc - Nsrc_tile ? QUDA_BOOLEAN_TRUE : QUDA_BOOLEAN_FALSE; + if (inv_deflate) eig_param.preserve_deflation = j < Nsrc - Nsrc_tile ? QUDA_BOOLEAN_TRUE : QUDA_BOOLEAN_FALSE; invertMultiSrcQuda(_hp_x.data(), _hp_b.data(), &inv_param); // move residuals to (i+j)^th location for verification after solves have finished @@ -428,13 +426,11 @@ std::vector> solve(test_t param) quda::comm_allreduce_sum(inv_param.gflops); inv_param.gflops /= comm_size() / num_sub_partition; quda::comm_allreduce_max(inv_param.secs); - printfQuda("Done: %d sub-partitions - %i iter / %g secs = %g Gflops, %g secs per source\n", - num_sub_partition, inv_param.iter, - inv_param.secs, inv_param.gflops / inv_param.secs, inv_param.secs / Nsrc_tile); + printfQuda("Done: %d sub-partitions - %i iter / %g secs = %g Gflops, %g secs per source\n", num_sub_partition, + inv_param.iter, inv_param.secs, inv_param.gflops / inv_param.secs, inv_param.secs / Nsrc_tile); if (inv_param.energy > 0) { printfQuda("Energy = %g J (%g J per source), Mean power = %g W, mean temp = %g C, mean clock = %f\n\n", - inv_param.energy, inv_param.energy / Nsrc_tile, - inv_param.power, inv_param.temp, inv_param.clock); + inv_param.energy, inv_param.energy / Nsrc_tile, inv_param.power, inv_param.temp, inv_param.clock); } } } diff --git a/tests/utils/command_line_params.cpp b/tests/utils/command_line_params.cpp index 4bf7c96c44..43d2745035 100644 --- a/tests/utils/command_line_params.cpp +++ b/tests/utils/command_line_params.cpp @@ -791,7 +791,8 @@ void add_eigen_option_group(std::shared_ptr quda_app) 
opgroup->add_option("--eig-ortho-block-size", eig_ortho_block_size, "The block size to use when orthonormalising vectors in hybrid modified Gram-Schmidt" "0 for always Classical, 1 for Modified, n > 1 for Hybrid)"); - opgroup->add_option("--eig-evals-batch-size", eig_evals_batch_size, "The batch size used when computing eigenvalues in the eigensolver"); + opgroup->add_option("--eig-evals-batch-size", eig_evals_batch_size, + "The batch size used when computing eigenvalues in the eigensolver"); opgroup->add_option("--eig-block-size", eig_block_size, "The block size to use in the block variant eigensolver"); opgroup->add_option( "--eig-n-ev-deflate", eig_n_ev_deflate, @@ -1014,8 +1015,9 @@ void add_multigrid_option_group(std::shared_ptr quda_app) "The number of pre-smoother applications to do at a given multigrid level (default 2)"); quda_app->add_mgoption(opgroup, "--mg-nvec", nvec, CLI::PositiveNumber, "Number of null-space vectors to define the multigrid transfer operator on a given level"); - quda_app->add_mgoption(opgroup, "--mg-nvec-batch", nvec_batch, CLI::PositiveNumber, - "Batch size to use when computing the null-space vectors to define the multigrid transfer operator on a given level"); + quda_app->add_mgoption( + opgroup, "--mg-nvec-batch", nvec_batch, + CLI::PositiveNumber, "Batch size to use when computing the null-space vectors to define the multigrid transfer operator on a given level"); opgroup->add_option("--mg-oblique-proj-check", oblique_proj_check, "Measure how well the null vector subspace adjusts the low eigenmode subspace (default false)"); opgroup->add_option("--mg-omega", omega, From 5c3192aea84254ea3cf0a578e72eb71fe533f749 Mon Sep 17 00:00:00 2001 From: Evan Weinberg Date: Tue, 8 Oct 2024 02:14:15 -0700 Subject: [PATCH 099/103] Updated the MILC HISQ MG interface for setting batch sizes --- lib/milc_interface.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lib/milc_interface.cpp b/lib/milc_interface.cpp index 24443b0a1f..bcf255201a 
100644 --- a/lib/milc_interface.cpp +++ b/lib/milc_interface.cpp @@ -2148,6 +2148,7 @@ void milcSetMultigridEigParam(QudaEigParam &mg_eig_param, mgInputStruct &input_s mg_eig_param.n_kr = input_struct.deflate_n_kr; // mg_eig_n_kr[level]; mg_eig_param.n_conv = input_struct.nvec[level]; mg_eig_param.n_ev_deflate = -1; // deflate everything that converged + mg_eig_param.compute_evals_batch_size = 16; // compute the eigenvalues in appropriate batches mg_eig_param.batched_rotate = 0; // mg_eig_batched_rotate[level]; mg_eig_param.require_convergence = QUDA_BOOLEAN_TRUE; // mg_eig_require_convergence[level] ? QUDA_BOOLEAN_TRUE : QUDA_BOOLEAN_FALSE; @@ -2328,6 +2329,7 @@ void milcSetMultigridParam(milcMultigridPack *mg_pack, QudaPrecision host_precis mg_param.setup_maxiter_refresh[i] = 0; // setup_maxiter_refresh[i]; mg_param.n_vec[i] = (i == 0) ? ((input_struct.optimized_kd == QUDA_TRANSFER_COARSE_KD) ? 24 : 3) : input_struct.nvec[i]; + mg_param.n_vec_batch[i] = 16; mg_param.n_block_ortho[i] = 2; // n_block_ortho[i]; // number of times to Gram-Schmidt mg_param.precision_null[i] = input_struct.preconditioner_precision; // precision to store the null-space basis mg_param.smoother_halo_precision[i] From 7903288629f0fcc474989fec5a1393ecc17a4b42 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Tue, 8 Oct 2024 02:58:24 -0700 Subject: [PATCH 100/103] Set QudaMultigridParam::n_vec_batch to invalid to force user to set this. 
Remove duplicate code --- lib/check_params.h | 8 +++++--- tests/utils/set_params.cpp | 1 - 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/lib/check_params.h b/lib/check_params.h index 51ed9acc38..f227ee4716 100644 --- a/lib/check_params.h +++ b/lib/check_params.h @@ -976,12 +976,14 @@ void printQudaMultigridParam(QudaMultigridParam *param) { #endif #ifdef INIT_PARAM - if (i Date: Tue, 8 Oct 2024 04:35:59 -0700 Subject: [PATCH 101/103] Made nvec_batch more robust in the MILC HISQ MG interface --- lib/milc_interface.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/milc_interface.cpp b/lib/milc_interface.cpp index bcf255201a..23473eaa4b 100644 --- a/lib/milc_interface.cpp +++ b/lib/milc_interface.cpp @@ -2148,7 +2148,7 @@ void milcSetMultigridEigParam(QudaEigParam &mg_eig_param, mgInputStruct &input_s mg_eig_param.n_kr = input_struct.deflate_n_kr; // mg_eig_n_kr[level]; mg_eig_param.n_conv = input_struct.nvec[level]; mg_eig_param.n_ev_deflate = -1; // deflate everything that converged - mg_eig_param.compute_evals_batch_size = 16; // compute the eigenvalues in appropriate batches + mg_eig_param.compute_evals_batch_size = (input_struct.nvec[level] % 16 == 0) ? 16 : 1; // compute the eigenvalues in appropriate batches mg_eig_param.batched_rotate = 0; // mg_eig_batched_rotate[level]; mg_eig_param.require_convergence = QUDA_BOOLEAN_TRUE; // mg_eig_require_convergence[level] ? QUDA_BOOLEAN_TRUE : QUDA_BOOLEAN_FALSE; @@ -2329,7 +2329,7 @@ void milcSetMultigridParam(milcMultigridPack *mg_pack, QudaPrecision host_precis mg_param.setup_maxiter_refresh[i] = 0; // setup_maxiter_refresh[i]; mg_param.n_vec[i] = (i == 0) ? ((input_struct.optimized_kd == QUDA_TRANSFER_COARSE_KD) ? 24 : 3) : input_struct.nvec[i]; - mg_param.n_vec_batch[i] = 16; + mg_param.n_vec_batch[i] = (i == 0) ? 1 : (mg_param.n_vec[i] % 16 == 0 ? 
16 : 1); mg_param.n_block_ortho[i] = 2; // n_block_ortho[i]; // number of times to Gram-Schmidt mg_param.precision_null[i] = input_struct.preconditioner_precision; // precision to store the null-space basis mg_param.smoother_halo_precision[i] From 2d56bfda5e56c66dcb86b13c55342d3c9f43cd57 Mon Sep 17 00:00:00 2001 From: Mahias Wagner Date: Wed, 9 Oct 2024 13:59:52 +0200 Subject: [PATCH 102/103] bump CPM (silences some warnings with newer cmake) --- cmake/CPM.cmake | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cmake/CPM.cmake b/cmake/CPM.cmake index a78671296c..baf2d8c344 100644 --- a/cmake/CPM.cmake +++ b/cmake/CPM.cmake @@ -2,8 +2,8 @@ # # SPDX-FileCopyrightText: Copyright (c) 2019-2023 Lars Melchior and contributors -set(CPM_DOWNLOAD_VERSION 0.38.5) -set(CPM_HASH_SUM "192aa0ccdc57dfe75bd9e4b176bf7fb5692fd2b3e3f7b09c74856fc39572b31c") +set(CPM_DOWNLOAD_VERSION 0.40.2) +set(CPM_HASH_SUM "c8cdc32c03816538ce22781ed72964dc864b2a34a310d3b7104812a5ca2d835d") if(CPM_SOURCE_CACHE) set(CPM_DOWNLOAD_LOCATION "${CPM_SOURCE_CACHE}/cpm/CPM_${CPM_DOWNLOAD_VERSION}.cmake") From 05b2bc6e1d1f2012b19b13da9a72fa5cb21f8b00 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Wed, 9 Oct 2024 14:16:48 -0700 Subject: [PATCH 103/103] Fix typo --- lib/solver.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/solver.hpp b/lib/solver.hpp index 5c308e461c..2482b8e6b3 100644 --- a/lib/solver.hpp +++ b/lib/solver.hpp @@ -51,8 +51,8 @@ namespace quda /** @brief Generate a Krylov space in a given basis @param[in] diracm Dirac matrix used to generate the Krylov space - @param[out] Ap dirac matrix times the Krylov basis vector sets - @param[in,out] p Krylov basis vector sest; assumes p[0] is in place + @param[out] Ap Dirac matrix times the Krylov basis vector sets + @param[in,out] p Krylov basis vector set; assumes p[0] is in place @param[in] n_krylov Size of krylov space @param[in] basis Basis type @param[in] m_map Slope mapping for Chebyshev basis; 
ignored for power basis