Skip to content
8 changes: 4 additions & 4 deletions include/quda.h
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,10 @@ extern "C" {
/** The t0 parameter for distance preconditioning, the timeslice where the source is located */
int distance_pc_t0;

/** Whether to use the smeared gauge field for the Dirac operator, usually
when defined as a spatial Laplacian: mainly used in computing Laplacian eigenvectors */
QudaBoolean use_smeared_gauge;

} QudaInvertParam;

// Parameter set for solving eigenvalue problems.
Expand Down Expand Up @@ -505,10 +509,6 @@ extern "C" {
false, but preserve_deflation would be true */
QudaBoolean preserve_evals;

/** Whether to use the smeared gauge field for the Dirac operator
for whose eigenvalues are are computing. */
bool use_smeared_gauge;

/** What type of Dirac operator we are using **/
/** If !(use_norm_op) && !(use_dagger) use M. **/
/** If use_dagger, use Mdag **/
Expand Down
2 changes: 1 addition & 1 deletion lib/check_params.h
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,6 @@ void printQudaEigParam(QudaEigParam *param) {
P(preserve_deflation, QUDA_BOOLEAN_FALSE);
P(preserve_deflation_space, 0);
P(preserve_evals, QUDA_BOOLEAN_TRUE);
P(use_smeared_gauge, false);
P(use_dagger, QUDA_BOOLEAN_FALSE);
P(use_norm_op, QUDA_BOOLEAN_FALSE);
P(compute_svd, QUDA_BOOLEAN_FALSE);
Expand Down Expand Up @@ -373,6 +372,7 @@ void printQudaInvertParam(QudaInvertParam *param) {
P(twist_flavor, QUDA_TWIST_INVALID);
P(laplace3D, INVALID_INT);
P(covdev_mu, INVALID_INT);
P(use_smeared_gauge, QUDA_BOOLEAN_FALSE);
#else
// asqtad and domain wall use mass parameterization
if (param->dslash_type == QUDA_STAGGERED_DSLASH || param->dslash_type == QUDA_ASQTAD_DSLASH
Expand Down
44 changes: 29 additions & 15 deletions lib/interface_quda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1436,9 +1436,10 @@ namespace quda {

void setDiracParam(DiracParam &diracParam, QudaInvertParam *inv_param, bool pc)
{
GaugeField *gaugePtr = (!inv_param->use_smeared_gauge) ? gaugePrecise : gaugeSmeared;
double kappa = inv_param->kappa;
if (inv_param->dirac_order == QUDA_CPS_WILSON_DIRAC_ORDER) {
kappa *= gaugePrecise->Anisotropy();
kappa *= gaugePtr->Anisotropy();
}

switch (inv_param->dslash_type) {
Expand Down Expand Up @@ -1528,7 +1529,7 @@ namespace quda {

diracParam.matpcType = inv_param->matpc_type;
diracParam.dagger = inv_param->dagger;
diracParam.gauge = inv_param->dslash_type == QUDA_ASQTAD_DSLASH ? gaugeFatPrecise : gaugePrecise;
diracParam.gauge = inv_param->dslash_type == QUDA_ASQTAD_DSLASH ? gaugeFatPrecise : gaugePtr;
diracParam.fatGauge = gaugeFatPrecise;
diracParam.longGauge = gaugeLongPrecise;
diracParam.clover = cloverPrecise;
Expand Down Expand Up @@ -1562,7 +1563,7 @@ namespace quda {
diracParam.commDim[i] = 1; // comms are always on
}

if (diracParam.gauge->Precision() != inv_param->cuda_prec_sloppy)
if ((!inv_param->use_smeared_gauge) && (diracParam.gauge->Precision() != inv_param->cuda_prec_sloppy))
errorQuda("Gauge precision %d does not match requested precision %d\n", diracParam.gauge->Precision(),
inv_param->cuda_prec_sloppy);
}
Expand All @@ -1580,7 +1581,7 @@ namespace quda {
diracParam.commDim[i] = 1; // comms are always on
}

if (diracParam.gauge->Precision() != inv_param->cuda_prec_refinement_sloppy)
if ((!inv_param->use_smeared_gauge) && (diracParam.gauge->Precision() != inv_param->cuda_prec_refinement_sloppy))
errorQuda("Gauge precision %d does not match requested precision %d\n", diracParam.gauge->Precision(),
inv_param->cuda_prec_refinement_sloppy);
}
Expand Down Expand Up @@ -1612,24 +1613,37 @@ namespace quda {
diracParam.gauge = gaugeFatPrecondition;
}

if (diracParam.gauge->Precision() != inv_param->cuda_prec_precondition)
if ((!inv_param->use_smeared_gauge) && (diracParam.gauge->Precision() != inv_param->cuda_prec_precondition))
errorQuda("Gauge precision %d does not match requested precision %d\n", diracParam.gauge->Precision(),
inv_param->cuda_prec_precondition);
}

void setDiracEigParam(DiracParam &diracParam, QudaInvertParam *inv_param, bool pc, bool use_smeared_gauge)
void setDiracEigParam(DiracParam &diracParam, QudaInvertParam *inv_param, bool pc)
{
setDiracParam(diracParam, inv_param, pc);

if (inv_param->overlap) {
diracParam.gauge = inv_param->dslash_type == QUDA_ASQTAD_DSLASH ? gaugeFatExtended : gaugeExtended;
diracParam.fatGauge = gaugeFatExtended;
diracParam.longGauge = gaugeLongExtended;
} else if (use_smeared_gauge) {
} else if (inv_param->use_smeared_gauge) {
if (!gaugeSmeared) errorQuda("No smeared gauge field present");
if (inv_param->dslash_type == QUDA_LAPLACE_DSLASH) {
if (gaugeSmeared->GhostExchange() == QUDA_GHOST_EXCHANGE_EXTENDED) {
GaugeFieldParam gauge_param(*gaugePrecise);
GaugeFieldParam gauge_param((gaugePrecise)? *gaugePrecise : *gaugeSmeared);
if (!gaugePrecise){
for (int k=0;k<gauge_param.nDim;++k){
gauge_param.x[k]-=2*gauge_param.r[k]; gauge_param.r[k]=0;} // smearedGauge is loaded as extended, so remove extensions
#ifdef MULTI_GPU
int x_face_size = gauge_param.x[1] * gauge_param.x[2] * gauge_param.x[3] / 2;
int y_face_size = gauge_param.x[0] * gauge_param.x[2] * gauge_param.x[3] / 2;
int z_face_size = gauge_param.x[0] * gauge_param.x[1] * gauge_param.x[3] / 2;
int t_face_size = gauge_param.x[0] * gauge_param.x[1] * gauge_param.x[2] / 2;
gauge_param.pad = std::max({x_face_size, y_face_size, z_face_size, t_face_size});
#endif
//gauge_param.link_type = QUDA_WILSON_LINKS;
gauge_param.ghostExchange = QUDA_GHOST_EXCHANGE_PAD;}
gauge_param.ghostExchange = QUDA_GHOST_EXCHANGE_PAD;
GaugeField gaugeEig(gauge_param);
copyExtendedGauge(gaugeEig, *gaugeSmeared, QUDA_CUDA_FIELD_LOCATION);
gaugeEig.exchangeGhost();
Expand All @@ -1644,6 +1658,7 @@ namespace quda {
diracParam.fatGauge = gaugeFatEigensolver;
diracParam.longGauge = gaugeLongEigensolver;
}

diracParam.clover = cloverEigensolver;

for (int i = 0; i < 4; i++) { diracParam.commDim[i] = 1; }
Expand Down Expand Up @@ -1697,8 +1712,7 @@ namespace quda {
dRef = Dirac::create(diracRefParam);
}

void createDiracWithEig(Dirac *&d, Dirac *&dSloppy, Dirac *&dPre, Dirac *&dEig, QudaInvertParam &param, bool pc_solve,
bool use_smeared_gauge)
void createDiracWithEig(Dirac *&d, Dirac *&dSloppy, Dirac *&dPre, Dirac *&dEig, QudaInvertParam &param, bool pc_solve)
{
DiracParam diracParam;
DiracParam diracSloppyParam;
Expand All @@ -1709,7 +1723,7 @@ namespace quda {
setDiracSloppyParam(diracSloppyParam, &param, pc_solve);
bool pre_comms_flag = (param.schwarz_type != QUDA_INVALID_SCHWARZ) ? false : true;
setDiracPreParam(diracPreParam, &param, pc_solve, pre_comms_flag);
setDiracEigParam(diracEigParam, &param, pc_solve, use_smeared_gauge);
setDiracEigParam(diracEigParam, &param, pc_solve);

d = Dirac::create(diracParam); // create the Dirac operator
dSloppy = Dirac::create(diracSloppyParam);
Expand Down Expand Up @@ -2406,6 +2420,7 @@ void checkClover(QudaInvertParam *param) {
quda::GaugeField *checkGauge(QudaInvertParam *param)
{
quda::GaugeField *U = param->dslash_type == QUDA_ASQTAD_DSLASH ? gaugeFatPrecise :
param->use_smeared_gauge ? gaugeSmeared :
gaugePrecise;

if (U == nullptr)
Expand All @@ -2415,7 +2430,7 @@ quda::GaugeField *checkGauge(QudaInvertParam *param)
errorQuda("Solve precision %d doesn't match gauge precision %d", param->cuda_prec, U->Precision());
}

if (param->dslash_type != QUDA_ASQTAD_DSLASH) {
if (param->dslash_type != QUDA_ASQTAD_DSLASH && !param->use_smeared_gauge) {
if (param->cuda_prec_sloppy != gaugeSloppy->Precision()
|| param->cuda_prec_precondition != gaugePrecondition->Precision()
|| param->cuda_prec_refinement_sloppy != gaugeRefinement->Precision()
Expand All @@ -2433,7 +2448,7 @@ quda::GaugeField *checkGauge(QudaInvertParam *param)
if (gaugeRefinement == nullptr) errorQuda("Refinement gauge field doesn't exist");
if (gaugeEigensolver == nullptr) errorQuda("Refinement gauge field doesn't exist");
if (param->overlap && gaugeExtended == nullptr) errorQuda("Extended gauge field doesn't exist");
} else {
} else if (!param->use_smeared_gauge) {
if (gaugeLongPrecise == nullptr) errorQuda("Precise gauge long field doesn't exist");

if (param->cuda_prec_sloppy != gaugeFatSloppy->Precision()
Expand Down Expand Up @@ -2585,10 +2600,9 @@ void eigensolveQuda(void **host_evecs, double _Complex *host_evals, QudaEigParam

// Create the dirac operator with a sloppy and a precon.
bool pc_solve = (inv_param->solve_type == QUDA_DIRECT_PC_SOLVE) || (inv_param->solve_type == QUDA_NORMOP_PC_SOLVE);
createDiracWithEig(d, dSloppy, dPre, dEig, *inv_param, pc_solve, eig_param->use_smeared_gauge);
createDiracWithEig(d, dSloppy, dPre, dEig, *inv_param, pc_solve);
Dirac &dirac = *dEig;
//------------------------------------------------------

// Construct vectors
//------------------------------------------------------
// Create host wrappers around application vector set
Expand Down
6 changes: 2 additions & 4 deletions lib/solve.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -317,8 +317,7 @@ namespace quda
getProfile().TPSTOP(QUDA_PROFILE_EPILOGUE);
}

void createDiracWithEig(Dirac *&d, Dirac *&dSloppy, Dirac *&dPre, Dirac *&dEig, QudaInvertParam &param, bool pc_solve,
bool use_smeared_gauge);
void createDiracWithEig(Dirac *&d, Dirac *&dSloppy, Dirac *&dPre, Dirac *&dEig, QudaInvertParam &param, bool pc_solve);

extern std::vector<ColorSpinorField> solutionResident;

Expand Down Expand Up @@ -349,8 +348,7 @@ namespace quda

// Create the dirac operator and operators for sloppy, precondition,
// and an eigensolver
createDiracWithEig(dirac, diracSloppy, diracPre, diracEig, param, pc_solve,
param.eig_param ? static_cast<QudaEigParam *>(param.eig_param)->use_smeared_gauge : false);
createDiracWithEig(dirac, diracSloppy, diracPre, diracEig, param, pc_solve);

// wrap CPU host side pointers
ColorSpinorParam cpuParam(hp_b[0], param, u.X(), pc_solution, param.input_location);
Expand Down
2 changes: 1 addition & 1 deletion tests/staggered_eigensolve_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ std::vector<double> eigensolve(test_t test_param)
eig_inv_param.solution_type = eig_param.use_pc ? QUDA_MATPC_SOLUTION : QUDA_MAT_SOLUTION;

// whether we are using the resident smeared gauge or not
eig_param.use_smeared_gauge = gauge_smear;
eig_param.invert_param->use_smeared_gauge = (gauge_smear ? QUDA_BOOLEAN_TRUE : QUDA_BOOLEAN_FALSE);

if (dslash_type == QUDA_LAPLACE_DSLASH) {
int dimension = laplace3D < 4 ? 3 : 4;
Expand Down
Loading