Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions qualtran/_infra/bloq.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
GeneralizerT,
SympySymbolAllocator,
)
from qualtran.simulation.classical_sim import ClassicalValRetT, ClassicalValT
from qualtran.simulation.classical_sim import ClassicalValRetT, ClassicalValT, MeasurementPhase
from qualtran.simulation.tensor import DiscardInd


Expand Down Expand Up @@ -286,7 +286,9 @@ def on_classical_vals(
except NotImplementedError as e:
raise NotImplementedError(f"{self} does not support classical simulation: {e}") from e

def basis_state_phase(self, **vals: 'ClassicalValT') -> Union[complex, None]:
def basis_state_phase(
self, **vals: 'ClassicalValT'
) -> Union[complex, 'MeasurementPhase', None]:
"""How this bloq phases classical basis states.

Override this method if your bloq represents classical logic with basis-state
Expand All @@ -297,7 +299,8 @@ def basis_state_phase(self, **vals: 'ClassicalValT') -> Union[complex, None]:
(X, CNOT, Toffoli, ...) and diagonal operations (T, CZ, CCZ, ...).

Bloq authors should override this method. If you are using an instantiated bloq object,
call TODO and not this method directly.
call `qualtran.simulation.classical_sim.do_phased_classical_simulation` or use
`qualtran.simulation.classical_sim.PhasedClassicalSimState`.

If this method is implemented, `on_classical_vals` must also be implemented.
If `on_classical_vals` is implemented but this method is not implemented, it is assumed
Expand Down
21 changes: 14 additions & 7 deletions qualtran/_infra/controlled.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
from qualtran.cirq_interop import CirqQuregT
from qualtran.drawing import WireSymbol
from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator
from qualtran.simulation.classical_sim import ClassicalValRetT, ClassicalValT
from qualtran.simulation.classical_sim import ClassicalValRetT, ClassicalValT, MeasurementPhase

ControlBit: TypeAlias = int
"""A control bit, either 0 or 1."""
Expand Down Expand Up @@ -380,10 +380,7 @@ def ctrl_spec(self) -> 'CtrlSpec':

@cached_property
def _thru_registers_only(self) -> bool:
for reg in self.subbloq.signature:
if reg.side != Side.THRU:
return False
return True
return self.signature.thru_registers_only

@staticmethod
def _make_ctrl_system(cb: '_ControlledBase') -> Tuple['_ControlledBase', 'AddControlledT']:
Expand Down Expand Up @@ -453,7 +450,9 @@ def on_classical_vals(self, **vals: 'ClassicalValT') -> Mapping[str, 'ClassicalV

return vals

def basis_state_phase(self, **vals: 'ClassicalValT') -> Union[complex, None]:
def basis_state_phase(
self, **vals: 'ClassicalValT'
) -> Union[complex, 'MeasurementPhase', None]:
"""Phasing action of controlled bloqs.

This involves conditionally doing the phasing action of `subbloq`. All implementers
Expand Down Expand Up @@ -533,7 +532,15 @@ def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -
from qualtran.drawing import Text

if reg is None:
return Text(f'C[{self.subbloq}]')
sub_title = self.subbloq.wire_symbol(None, idx)
if not isinstance(sub_title, Text):
raise ValueError(
f"{self.subbloq} should return a `Text` object for reg=None wire symbol."
)
if sub_title.text == '':
return Text('')

return Text(f'C[{sub_title.text}]')
if reg.name not in self.ctrl_reg_names:
# Delegate to subbloq
return self.subbloq.wire_symbol(reg, idx)
Expand Down
8 changes: 8 additions & 0 deletions qualtran/_infra/registers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import enum
import itertools
from collections import defaultdict
from functools import cached_property
from typing import cast, Dict, Iterable, Iterator, List, overload, Tuple, Union

import attrs
Expand Down Expand Up @@ -230,6 +231,13 @@ def build_from_dtypes(cls, **registers: QCDType) -> 'Signature':
"""
return cls(Register(name=k, dtype=v) for k, v in registers.items() if v.num_qubits)

@cached_property
def thru_registers_only(self) -> bool:
for reg in self:
if reg.side != Side.THRU:
return False
return True

def lefts(self) -> Iterable[Register]:
"""Iterable over all registers that appear on the LEFT as input."""
yield from self._lefts.values()
Expand Down
5 changes: 4 additions & 1 deletion qualtran/bloqs/mcmt/and_bloq.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,10 @@ def on_classical_vals(
return {'ctrl': ctrl, 'target': out}

# Uncompute
assert target == out
if target != out:
raise ValueError(
f"Inconsistent `target` found for uncomputing `And`: {ctrl=}, {target=}. Expected target={out}"
)
return {'ctrl': ctrl}

def my_tensors(
Expand Down
15 changes: 12 additions & 3 deletions qualtran/resource_counting/_bloq_counts.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,15 @@ class QECGatesCost(CostKey[GateCounts]):
legacy_shims: bool = False

def compute(self, bloq: 'Bloq', get_callee_cost: Callable[['Bloq'], GateCounts]) -> GateCounts:
from qualtran.bloqs.basic_gates import GlobalPhase, Identity, Toffoli, TwoBitCSwap
from qualtran.bloqs.basic_gates import (
Discard,
GlobalPhase,
Identity,
MeasX,
MeasZ,
Toffoli,
TwoBitCSwap,
)
from qualtran.bloqs.basic_gates._shims import Measure
from qualtran.bloqs.bookkeeping._bookkeeping_bloq import _BookkeepingBloq
from qualtran.bloqs.mcmt import And, MultiTargetCNOT
Expand Down Expand Up @@ -326,7 +334,7 @@ def compute(self, bloq: 'Bloq', get_callee_cost: Callable[['Bloq'], GateCounts])
return GateCounts(toffoli=1)

# Measurement
if isinstance(bloq, Measure):
if isinstance(bloq, (Measure, MeasZ, MeasX)):
return GateCounts(measurement=1)

# 'And' bloqs
Expand Down Expand Up @@ -370,9 +378,10 @@ def compute(self, bloq: 'Bloq', get_callee_cost: Callable[['Bloq'], GateCounts])
return GateCounts()

# Bookkeeping, empty bloqs
if isinstance(bloq, _BookkeepingBloq) or isinstance(bloq, (GlobalPhase, Identity)):
if isinstance(bloq, _BookkeepingBloq) or isinstance(bloq, (GlobalPhase, Identity, Discard)):
return GateCounts()

# Rotations
if bloq_is_rotation(bloq):
return GateCounts(rotation=1)

Expand Down
Loading