Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions source/pip/benchmarks/bench_qre.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

import timeit
from dataclasses import dataclass, KW_ONLY, field
from qsharp.qre import linear_function, generic_function, instruction
from qsharp.qre import linear_function, generic_function
from qsharp.qre._architecture import _make_instruction
from qsharp.qre.models import AQREGateBased, SurfaceCode
from qsharp.qre._enumeration import _enumerate_instances

Expand Down Expand Up @@ -39,7 +40,7 @@ def bench_enumerate_isas():
sys.path.append(os.path.join(os.path.dirname(__file__), "../tests"))
from test_qre import ExampleLogicalFactory, ExampleFactory # type: ignore

ctx = AQREGateBased().context()
ctx = AQREGateBased(gate_time=50, measurement_time=100).context()

# Hierarchical factory using from_components
query = SurfaceCode.q() * ExampleLogicalFactory.q(
Expand All @@ -62,7 +63,7 @@ def bench_enumerate_isas():
def bench_function_evaluation_linear():
fl = linear_function(12)

inst = instruction(42, arity=None, space=fl, time=1, error_rate=1.0)
inst = _make_instruction(42, 0, None, 1, fl, None, 1.0, {})
number = 1000
duration = timeit.timeit(
"inst.space(5)",
Expand All @@ -83,7 +84,7 @@ def func(arity: int) -> int:

fg = generic_function(func)

inst = instruction(42, arity=None, space=fg, time=1, error_rate=1.0)
inst = _make_instruction(42, 0, None, 1, fg, None, 1.0, {})
number = 1000
duration = timeit.timeit(
"inst.space(5)",
Expand Down
15 changes: 11 additions & 4 deletions source/pip/qsharp/qre/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,19 @@

from ._application import Application
from ._architecture import Architecture
from ._estimation import estimate
from ._estimation import (
estimate,
EstimationTable,
EstimationTableColumn,
EstimationTableEntry,
)
from ._instruction import (
LOGICAL,
PHYSICAL,
Encoding,
ISATransform,
PropertyKey,
constraint,
instruction,
InstructionSource,
)
from ._isa_enumeration import ISAQuery, ISARefNode, ISA_ROOT
Expand All @@ -31,14 +35,13 @@
linear_function,
instruction_name,
)
from ._trace import LatticeSurgery, PSSPC, TraceQuery
from ._trace import LatticeSurgery, PSSPC, TraceQuery, TraceTransform

__all__ = [
"block_linear_function",
"constant_function",
"constraint",
"estimate",
"instruction",
"linear_function",
"Application",
"Architecture",
Expand All @@ -47,6 +50,9 @@
"ConstraintBound",
"Encoding",
"EstimationResult",
"EstimationTable",
"EstimationTableColumn",
"EstimationTableEntry",
"FactoryResult",
"generic_function",
"instruction_name",
Expand All @@ -63,6 +69,7 @@
"PSSPC",
"Trace",
"TraceQuery",
"TraceTransform",
"LOGICAL",
"PHYSICAL",
]
60 changes: 54 additions & 6 deletions source/pip/qsharp/qre/_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import types
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from types import NoneType
from typing import (
Any,
ClassVar,
Generic,
Protocol,
Expand All @@ -18,7 +18,8 @@
)

from ._enumeration import _enumerate_instances
from ._qre import Trace
from ._qre import Trace, EstimationResult
from ._trace import TraceQuery


class DataclassProtocol(Protocol):
Expand Down Expand Up @@ -50,9 +51,19 @@ class Application(ABC, Generic[TraceParameters]):
def get_trace(self, parameters: TraceParameters) -> Trace:
"""Return the trace corresponding to this application."""

def context(self, **kwargs) -> _Context:
@staticmethod
def q(**kwargs) -> TraceQuery:
return TraceQuery(NoneType, **kwargs)

def context(self) -> _Context:
"""Create a new enumeration context for this application."""
return _Context(self, **kwargs)
return _Context(self)

def post_process(
self, parameters: TraceParameters, estimation: EstimationResult
) -> EstimationResult:
"""Post-process an estimation result for a given set of trace parameters."""
return estimation

def enumerate_traces(
self,
Expand Down Expand Up @@ -80,15 +91,52 @@ def enumerate_traces(
for instances in _enumerate_instances(cast(type, param_type), **kwargs):
yield self.get_trace(instances)

def enumerate_traces_with_parameters(
self,
**kwargs,
) -> Generator[tuple[TraceParameters, Trace], None, None]:
"""Yields (parameters, trace) pairs for an application.

Like ``enumerate_traces``, but each yielded trace is accompanied by the
trace parameters that were used to generate it.

Args:
**kwargs: Domain overrides forwarded to ``_enumerate_instances``.

Returns:
Generator[tuple[TraceParameters, Trace], None, None]: A generator
of (parameters, trace) pairs.
"""

param_type = get_type_hints(self.__class__.get_trace).get("parameters")
if param_type is types.NoneType:
yield None, self.get_trace(None) # type: ignore
return

if isinstance(param_type, TypeVar):
for c in param_type.__constraints__:
if c is not types.NoneType:
param_type = c
break

if self._parallel_traces:
instances = list(_enumerate_instances(cast(type, param_type), **kwargs))
with ThreadPoolExecutor() as executor:
for instance, trace in zip(
instances, executor.map(self.get_trace, instances)
):
yield instance, trace
else:
for instance in _enumerate_instances(cast(type, param_type), **kwargs):
yield instance, self.get_trace(instance)

def disable_parallel_traces(self):
"""Disable parallel trace generation for this application."""
self._parallel_traces = False


class _Context:
application: Application
kwargs: dict[str, Any]

def __init__(self, application: Application, **kwargs):
self.application = application
self.kwargs = kwargs
Loading