Skip to content

Commit b7754db

Browse files
authored
Refactored Function Signatures and Type Hints (#42)
- Removed `Optional` from function arguments where the default value is not `None`, adhering to best practices and resolving redundancies.
1 parent 0ebfad7 commit b7754db

15 files changed

Lines changed: 120 additions & 125 deletions

File tree

tnco/app/app.py

Lines changed: 26 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -147,19 +147,18 @@ def load(binary):
147147

148148
def load_tn(obj: Any,
149149
*,
150-
fuse: Optional[float] = 4,
151-
decompose_hyper_inds: Optional[bool] = True,
152-
simplify_circuit: Optional[bool] = True,
153-
initial_state: Optional[Union[str, Dict[Qubit, Matrix],
154-
None]] = '0',
155-
final_state: Optional[Union[str, Dict[Qubit, Matrix], None]] = '0',
156-
output_index_token: Optional[str] = '*',
157-
sparse_index_token: Optional[str] = '/',
158-
atol: Optional[float] = 1e-5,
150+
fuse: float = 4,
151+
decompose_hyper_inds: bool = True,
152+
simplify_circuit: bool = True,
153+
initial_state: Union[str, Dict[Qubit, Matrix], None] = '0',
154+
final_state: Union[str, Dict[Qubit, Matrix], None] = '0',
155+
output_index_token: str = '*',
156+
sparse_index_token: str = '/',
157+
atol: float = 1e-5,
159158
dtype: Optional[Any] = None,
160159
backend: Optional[str] = None,
161160
seed: Optional[int] = None,
162-
verbose: Optional[int] = False) -> TensorNetwork:
161+
verbose: int = False) -> TensorNetwork:
163162
"""Loads a tensor network from various object types.
164163
165164
The function loads a tensor network from ``obj`` of any type. See Notes for
@@ -570,8 +569,8 @@ def dump_results(tn: TensorNetwork,
570569
*,
571570
output_format: Optional[str] = None,
572571
output_filename: Optional[str] = None,
573-
output_compression: Optional[str] = 'auto',
574-
overwrite_output_file: Optional[bool] = False,
572+
output_compression: str = 'auto',
573+
overwrite_output_file: bool = False,
575574
**kwargs) -> Any:
576575
"""Dumps results to a file or returns them.
577576
@@ -745,18 +744,18 @@ class BaseOptimizer:
745744
verbose: If ``True``, prints verbose output.
746745
"""
747746
max_width: Optional[float] = None
748-
n_jobs: Optional[int] = -1
749-
width_type: Optional[str] = 'float32'
750-
cost_type: Optional[str] = 'float64'
747+
n_jobs: int = -1
748+
width_type: str = 'float32'
749+
cost_type: str = 'float64'
751750
output_format: Optional[str] = None
752751
output_filename: Optional[str] = None
753-
output_compression: Optional[str] = 'auto'
754-
overwrite_output_file: Optional[bool] = False
755-
atol: Optional[float] = 1e-5
752+
output_compression: str = 'auto'
753+
overwrite_output_file: bool = False
754+
atol: float = 1e-5
756755
dtype: Optional[Any] = None
757756
backend: Optional[str] = None
758757
seed: Optional[int] = None
759-
verbose: Optional[int] = False
758+
verbose: int = False
760759

761760
def optimize(self, *args: Any, **kwargs: Any) -> Any:
762761
raise NotImplementedError()
@@ -787,20 +786,20 @@ def __post_init__(self) -> None:
787786
self._dump_results(None, None, check_only=True)
788787

789788

790-
def Optimizer(method: Optional[str] = 'sa',
789+
def Optimizer(method: str = 'sa',
791790
max_width: Optional[float] = None,
792-
n_jobs: Optional[int] = -1,
793-
width_type: Optional[str] = 'float32',
794-
cost_type: Optional[str] = 'float64',
791+
n_jobs: int = -1,
792+
width_type: str = 'float32',
793+
cost_type: str = 'float64',
795794
output_format: Optional[str] = None,
796795
output_filename: Optional[str] = None,
797-
output_compression: Optional[str] = 'auto',
798-
overwrite_output_file: Optional[bool] = False,
799-
atol: Optional[float] = 1e-5,
796+
output_compression: str = 'auto',
797+
overwrite_output_file: bool = False,
798+
atol: float = 1e-5,
800799
dtype: Optional[Any] = None,
801800
backend: Optional[str] = None,
802801
seed: Optional[int] = None,
803-
verbose: Optional[int] = False) -> BaseOptimizer:
802+
verbose: int = False) -> BaseOptimizer:
804803
"""Factory function to create an optimizer.
805804
806805
Optimize the tensor network.

tnco/app/circuit/sampling.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -95,20 +95,20 @@ def sample(
9595
circuit: Union[Iterable[Tuple[Matrix, Tuple[Qubit]]],
9696
SamplingIntermediateState],
9797
optimizer: Optimizer,
98-
n_samples: Optional[int] = 1,
98+
n_samples: int = 1,
9999
*,
100-
simplify: Optional[bool] = True,
101-
use_matrix_commutation: Optional[bool] = True,
102-
decompose_hyper_inds: Optional[bool] = True,
103-
fuse: Optional[float] = 4,
100+
simplify: bool = True,
101+
use_matrix_commutation: bool = True,
102+
decompose_hyper_inds: bool = True,
103+
fuse: float = 4,
104104
qubit_order: Optional[Iterable[Qubit]] = None,
105-
normalize: Optional[bool] = True,
106-
return_intermediate_state_only: Optional[bool] = False,
105+
normalize: bool = True,
106+
return_intermediate_state_only: bool = False,
107107
dtype: Optional[Any] = None,
108108
optimization_backend: Optional[str] = None,
109109
contraction_backend: Optional[str] = None,
110110
seed: Optional[int] = None,
111-
verbose: Optional[int] = False,
111+
verbose: int = False,
112112
**optimize_params
113113
) -> Union[Tuple[Dict[str, int], Tuple[Qubit]], SamplingIntermediateState]:
114114
"""Sample bitstrings from a circuit.
@@ -442,14 +442,14 @@ class Sampler:
442442
verbose: Verbose output.
443443
"""
444444
max_width: Optional[float] = None
445-
n_jobs: Optional[int] = -1
446-
width_type: Optional[str] = 'float32'
447-
cost_type: Optional[str] = 'float64'
448-
atol: Optional[float] = 1e-5
445+
n_jobs: int = -1
446+
width_type: str = 'float32'
447+
cost_type: str = 'float64'
448+
atol: float = 1e-5
449449
dtype: Optional[Any] = None
450450
optimization_backend: Optional[str] = None
451451
seed: Optional[int] = None
452-
verbose: Optional[int] = False
452+
verbose: int = False
453453

454454
def __post_init__(self):
455455
# Get rng
@@ -475,15 +475,15 @@ def __post_init__(self):
475475
def sample(
476476
self,
477477
circuit: Union[Circuit, SamplingIntermediateState],
478-
n_samples: Optional[int] = 1,
478+
n_samples: int = 1,
479479
*,
480-
simplify: Optional[bool] = True,
481-
use_matrix_commutation: Optional[bool] = True,
482-
decompose_hyper_inds: Optional[bool] = True,
483-
fuse: Optional[float] = 4,
480+
simplify: bool = True,
481+
use_matrix_commutation: bool = True,
482+
decompose_hyper_inds: bool = True,
483+
fuse: float = 4,
484484
qubit_order: Optional[Iterable[Qubit]] = None,
485-
normalize: Optional[bool] = True,
486-
return_intermediate_state_only: Optional[bool] = False,
485+
normalize: bool = True,
486+
return_intermediate_state_only: bool = False,
487487
contraction_backend: Optional[str] = None,
488488
**optimize_params
489489
) -> Union[Tuple[Dict[str, int], Tuple[Qubit]], SamplingIntermediateState]:

tnco/app/finite_width/sa.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,9 +113,9 @@ def optimize(self,
113113
tn: Any,
114114
betas: Union[Tuple[float, float], Iterable[float]],
115115
n_steps: Optional[int] = None,
116-
n_runs: Optional[int] = 1,
116+
n_runs: int = 1,
117117
n_projs: Optional[int] = None,
118-
update_slices: Optional[int] = 10,
118+
update_slices: int = 10,
119119
timeout: Optional[float] = None,
120120
**load_tn_options) -> Any:
121121
"""Optimizes the tensor network ``tn``.

tnco/app/infinite_memory/sa.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def optimize(self,
9797
tn: Any,
9898
betas: Union[Tuple[float, float], Iterable[float]],
9999
n_steps: Optional[int] = None,
100-
n_runs: Optional[int] = 1,
100+
n_runs: int = 1,
101101
n_projs: Optional[int] = None,
102102
timeout: Optional[float] = None,
103103
**load_tn_options) -> Any:

tnco/ctree.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ def __init__(self,
7070
dims: Union[Dict[Index, int], int],
7171
*,
7272
output_inds: Optional[Iterable[Index]] = None,
73-
check_shared_inds: Optional[bool] = False,
74-
verbose: Optional[bool] = False,
73+
check_shared_inds: bool = False,
74+
verbose: bool = False,
7575
**kwargs) -> None:
7676
# Get cache if present
7777
_cache = kwargs.pop('_cache', None)
@@ -407,7 +407,7 @@ def max_width(self) -> float:
407407
def traverse_tree(ctree: ContractionTree,
408408
callback: Callable[[int], NoReturn],
409409
*,
410-
verbose: Optional[int] = False) -> NoReturn:
410+
verbose: int = False) -> NoReturn:
411411
"""Traverses ``tree`` and calls ``callback`` for each node.
412412
413413
Traverses ``tree`` and calls ``callback`` for each node of the tree. The

tnco/optimize/finite_width/cost_model.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,10 @@ class SimpleCostModel(BaseCostModel):
7373
def __init__(self,
7474
max_width: float,
7575
*,
76-
width_type: Optional[Literal['float32', 'float64',
77-
'float128']] = 'float32',
78-
cost_type: Optional[Literal['float32', 'float64', 'float128',
79-
'float1024']] = 'float64',
76+
width_type: Literal['float32', 'float64',
77+
'float128'] = 'float32',
78+
cost_type: Literal['float32', 'float64', 'float128',
79+
'float1024'] = 'float64',
8080
sparse_inds: Optional[Iterable[Index]] = None,
8181
n_projs: Optional[int] = None):
8282

tnco/optimize/finite_width/optimizer.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,12 @@ def __init__(self,
5555
ctree: ContractionTree,
5656
cmodel: BaseCostModel,
5757
*,
58-
slice_update: Optional[Literal['greedy']] = 'greedy',
58+
slice_update: Literal['greedy'] = 'greedy',
5959
max_number_new_slices: int = 0,
60-
skip_slices: Iterable[Index] = None,
60+
skip_slices: Optional[Iterable[Index]] = None,
6161
seed: Optional[Union[int, str]] = None,
62-
disable_shared_inds: Optional[bool] = False,
63-
atol: Optional[float] = 1e-5,
62+
disable_shared_inds: bool = False,
63+
atol: float = 1e-5,
6464
**kwargs) -> None:
6565

6666
# Check cost model
@@ -292,7 +292,7 @@ def min_slices(self) -> FrozenSet[Index]:
292292
def is_valid(self,
293293
*,
294294
atol: float = 1e-5,
295-
return_message: str = False) -> bool:
295+
return_message: bool = False) -> bool:
296296
"""Check if ``Optimizer`` is in a valid state.
297297
298298
Check if ``Optimizer`` is in a valid state.

tnco/optimize/infinite_memory/cost_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,8 @@ class SimpleCostModel(BaseCostModel):
6464

6565
def __init__(self,
6666
*,
67-
cost_type: Optional[Literal['float32', 'float64', 'float128',
68-
'float1024']] = 'float64',
67+
cost_type: Literal['float32', 'float64', 'float128',
68+
'float1024'] = 'float64',
6969
sparse_inds: Optional[Iterable[Index]] = None,
7070
n_projs: Optional[int] = None):
7171

tnco/optimize/infinite_memory/optimizer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@ def __init__(self,
4949
cmodel: BaseCostModel,
5050
*,
5151
seed: Optional[Union[int, str]] = None,
52-
disable_shared_inds: Optional[bool] = False,
53-
atol: Optional[float] = 1e-5,
52+
disable_shared_inds: bool = False,
53+
atol: float = 1e-5,
5454
**kwargs) -> None:
5555

5656
# Check cost model
@@ -200,7 +200,7 @@ def log2_min_total_cost(self) -> float:
200200
"""
201201
return self._optimizer.log2_min_total_cost
202202

203-
def is_valid(self, *, atol: float = 1e-5, return_message: str = False):
203+
def is_valid(self, *, atol: float = 1e-5, return_message: bool = False):
204204
"""Check if ``Optimizer`` is in a valid state.
205205
206206
Check if ``Optimizer`` is in a valid state.

tnco/optimize/prob.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
"""Probability functions for optimization."""
1515

1616
from importlib import import_module
17-
from typing import Literal, Optional
17+
from typing import Literal
1818
from warnings import warn
1919

2020
__all__ = [
@@ -23,9 +23,8 @@
2323

2424

2525
def BaseProbability(*,
26-
cost_type: Optional[Literal['float32', 'float64',
27-
'float128',
28-
'float1024']] = 'float64'):
26+
cost_type: Literal['float32', 'float64', 'float128',
27+
'float1024'] = 'float64'):
2928
"""Factory for acceptance probability (always accept).
3029
3130
Always accept any proposed move.
@@ -56,8 +55,8 @@ def BaseProbability(*,
5655

5756

5857
def Greedy(*,
59-
cost_type: Optional[Literal['float32', 'float64', 'float128',
60-
'float1024']] = 'float64'):
58+
cost_type: Literal['float32', 'float64', 'float128',
59+
'float1024'] = 'float64'):
6160
"""Factory for greedy probability.
6261
6362
Accept a move only if the cost is not increasing.
@@ -91,9 +90,8 @@ def Greedy(*,
9190

9291
def SimulatedAnnealing(beta: float = 0,
9392
*,
94-
cost_type: Optional[Literal['float32', 'float64',
95-
'float128',
96-
'float1024']] = 'float64'):
93+
cost_type: Literal['float32', 'float64', 'float128',
94+
'float1024'] = 'float64'):
9795
"""Factory for Simulated Annealing (Deprecated).
9896
9997
Accept a move using the Metropolis-Hastings probability. Deprecated in
@@ -119,9 +117,8 @@ def SimulatedAnnealing(beta: float = 0,
119117

120118
def MetropolisHastings(beta: float = 0,
121119
*,
122-
cost_type: Optional[Literal['float32', 'float64',
123-
'float128',
124-
'float1024']] = 'float64'):
120+
cost_type: Literal['float32', 'float64', 'float128',
121+
'float1024'] = 'float64'):
125122
"""Factory for Metropolis-Hastings probability.
126123
127124
Accept a move using the Metropolis-Hastings probability.

0 commit comments

Comments
 (0)