-
Notifications
You must be signed in to change notification settings - Fork 65
typing: adding types to extensions #251
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,22 +1,24 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import getopt | ||
| import os | ||
| import re | ||
| import sqlite3 | ||
| import sys | ||
| from collections import OrderedDict | ||
| from typing import Any | ||
|
|
||
|
|
||
| def get_tperiods(inp_f): | ||
| def get_tperiods(inp_f: str) -> dict[str, list[int]]: | ||
| file_ty = re.search(r'(\w+)\.(\w+)\b', inp_f) # Extract the input filename and extension | ||
|
|
||
| if not file_ty: | ||
| raise 'The file type %s is not recognized.' % inp_f | ||
| raise Exception(f'The file type {inp_f} is not recognized.') | ||
|
|
||
| elif file_ty.group(2) not in ('db', 'sqlite', 'sqlite3', 'sqlitedb'): | ||
| raise Exception('Please specify a database for finding scenarios') | ||
|
Comment on lines
+12
to
19
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧹 Nitpick | 🔵 Trivial Prefer Also applies to: 47-55 🧰 Tools🪛 Ruff (0.14.10)16-16: Create your own exception (TRY002) 16-16: Avoid specifying long messages outside the exception class (TRY003) 19-19: Create your own exception (TRY002) 19-19: Avoid specifying long messages outside the exception class (TRY003) 🤖 Prompt for AI Agents |
||
|
|
||
| periods_list = {} | ||
| periods_set = set() | ||
| periods_list: dict[str, list[int]] = {} | ||
|
|
||
| con = sqlite3.connect(inp_f) | ||
| cur = con.cursor() # a database cursor is a control structure that enables traversal over | ||
|
|
@@ -30,7 +32,7 @@ def get_tperiods(inp_f): | |
| x.append(row[0]) | ||
| for y in x: | ||
| cur.execute( | ||
| "SELECT DISTINCT period FROM output_flow_out WHERE scenario is '" + str(y) + "'" | ||
| 'SELECT DISTINCT period FROM output_flow_out WHERE scenario = ?', (y,) | ||
| ) | ||
|
coderabbitai[bot] marked this conversation as resolved.
|
||
| periods_list[y] = [] | ||
| for per in cur: | ||
|
|
@@ -42,17 +44,16 @@ def get_tperiods(inp_f): | |
| return dict(OrderedDict(sorted(periods_list.items(), key=lambda x: x[1]))) | ||
|
|
||
|
|
||
| def get_scenario(inp_f): | ||
| def get_scenario(inp_f: str) -> dict[str, str]: | ||
| file_ty = re.search(r'(\w+)\.(\w+)\b', inp_f) # Extract the input filename and extension | ||
|
|
||
| if not file_ty: | ||
| raise 'The file type %s is not recognized.' % inp_f | ||
| raise Exception(f'The file type {inp_f} is not recognized.') | ||
|
|
||
| elif file_ty.group(2) not in ('db', 'sqlite', 'sqlite3', 'sqlitedb'): | ||
| raise Exception('Please specify a database for finding scenarios') | ||
|
|
||
| scene_list = {} | ||
| scene_set = set() | ||
| scene_list: dict[str, str] = {} | ||
|
|
||
| con = sqlite3.connect(inp_f) | ||
| cur = con.cursor() # a database cursor is a control structure that enables traversal over | ||
|
|
@@ -70,9 +71,9 @@ def get_scenario(inp_f): | |
| return dict(OrderedDict(sorted(scene_list.items(), key=lambda x: x[1]))) | ||
|
|
||
|
|
||
| def get_comm(inp_f, db_dat): | ||
| comm_list = {} | ||
| comm_set = set() | ||
| def get_comm(inp_f: str, db_dat: bool) -> OrderedDict[str, str]: | ||
| comm_list: dict[str, str] = {} | ||
| comm_set: set[str] = set() | ||
|
Comment on lines
+74
to
+76
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Boolean positional params ( Also applies to: 142-145 🧰 Tools🪛 Ruff (0.14.10)74-74: Boolean-typed positional argument in function definition (FBT001) 🤖 Prompt for AI Agents |
||
| is_query_empty = False | ||
|
|
||
| if not db_dat: | ||
|
|
@@ -138,9 +139,9 @@ def get_comm(inp_f, db_dat): | |
| return OrderedDict(sorted(comm_list.items(), key=lambda x: x[1])) | ||
|
|
||
|
|
||
| def get_tech(inp_f, db_dat): | ||
| tech_list = {} | ||
| tech_set = set() | ||
| def get_tech(inp_f: str, db_dat: bool) -> OrderedDict[str, str]: | ||
| tech_list: dict[str, str] = {} | ||
| tech_set: set[str] = set() | ||
| is_query_empty = False | ||
|
|
||
| if not db_dat: | ||
|
|
@@ -199,13 +200,13 @@ def get_tech(inp_f, db_dat): | |
| return OrderedDict(sorted(tech_list.items(), key=lambda x: x[1])) | ||
|
|
||
|
|
||
| def is_db_overwritten(db_file, inp_dat_file): | ||
| def is_db_overwritten(db_file: str, inp_dat_file: str) -> bool: | ||
| if os.path.basename(db_file) == '0': | ||
| return False | ||
|
|
||
| try: | ||
| con = sqlite3.connect(db_file) | ||
| except: | ||
| except sqlite3.Error: | ||
| return False | ||
|
coderabbitai[bot] marked this conversation as resolved.
|
||
| cur = con.cursor() # A database cursor enables traversal over DB records | ||
| con.text_factory = str # This ensures data is explored with UTF-8 encoding | ||
|
|
@@ -214,15 +215,15 @@ def is_db_overwritten(db_file, inp_dat_file): | |
| # IF output file is empty database. | ||
| cur.execute('SELECT * FROM Technology') | ||
| is_db_empty = False # False for empty db file | ||
| for elem in cur: | ||
| for _ in cur: | ||
| is_db_empty = True # True for non-empty db file | ||
| break | ||
| # This file could be schema with populated results from previous run. Or it could be a normal | ||
| # db file. | ||
| if is_db_empty: | ||
| cur.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='input_file';") | ||
| does_input_file_table_exist = False | ||
| for i in cur: # This means that the 'input_file' table exists in db. | ||
| for _ in cur: # This means that the 'input_file' table exists in db. | ||
| does_input_file_table_exist = True | ||
| if does_input_file_table_exist: # This block distinguishes normal database from schema. | ||
| # This is schema file. | ||
|
|
@@ -247,7 +248,7 @@ def is_db_overwritten(db_file, inp_dat_file): | |
| return False | ||
|
|
||
|
|
||
| def help_user(): | ||
| def help_user() -> None: | ||
| print( | ||
| """Use as: | ||
| python get_comm_tech.py -i (or --input) <input filename> | ||
|
|
@@ -259,8 +260,8 @@ def help_user(): | |
| ) | ||
|
|
||
|
|
||
| def get_info(inputs): | ||
| inp_file = None | ||
| def get_info(inputs: dict[str, str]) -> Any: | ||
| inp_file: str | None = None | ||
|
Comment on lines
+263
to
+264
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion | 🟠 Major
This dispatcher only returns period lists or string-keyed mappings, so 🧰 Tools🪛 Ruff (0.15.6)[warning] 263-338: Missing explicit Add explicit (RET503) [warning] 263-263: Too many branches (22 > 12) (PLR0912) [warning] 263-263: Dynamically typed expressions (typing.Any) are disallowed in (ANN401) 🤖 Prompt for AI Agents |
||
| tech_flag = False | ||
| comm_flag = False | ||
| scene = False | ||
|
|
@@ -317,8 +318,8 @@ def get_info(inputs): | |
|
|
||
| else: | ||
| print( | ||
| 'The input file type %s is not recognized. Please specify a database or a text file.' | ||
| % inp_file | ||
| f'The input file type {inp_file} is not recognized. Please specify a database ' | ||
| 'or a text file.' | ||
| ) | ||
| sys.exit(2) | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,28 +1,28 @@ | ||||||||||
| # from __future__ import division | ||||||||||
| import time | ||||||||||
| from __future__ import annotations | ||||||||||
|
|
||||||||||
| import csv | ||||||||||
| import sqlite3 | ||||||||||
| from importlib import resources | ||||||||||
| from pathlib import Path | ||||||||||
| from typing import Any | ||||||||||
|
|
||||||||||
| from joblib import Parallel, delayed # type: ignore[import-untyped] | ||||||||||
| from numpy import array | ||||||||||
| from pyomo.dataportal import DataPortal | ||||||||||
| from SALib.analyze import morris # type: ignore[import-untyped] | ||||||||||
| from SALib.sample.morris import sample # type: ignore[import-untyped] | ||||||||||
| from SALib.util import compute_groups_matrix, read_param_file # type: ignore[import-untyped] | ||||||||||
|
|
||||||||||
| from temoa._internal import run_actions | ||||||||||
| from temoa._internal.table_writer import TableWriter | ||||||||||
| from temoa.core.config import TemoaConfig | ||||||||||
| from temoa.data_io.hybrid_loader import HybridLoader | ||||||||||
|
|
||||||||||
| start_time = time.time() | ||||||||||
| import sqlite3 | ||||||||||
|
|
||||||||||
| from joblib import Parallel, delayed | ||||||||||
| from numpy import array | ||||||||||
| from SALib.analyze import morris | ||||||||||
| from SALib.sample.morris import sample | ||||||||||
| from SALib.util import compute_groups_matrix, read_param_file | ||||||||||
|
|
||||||||||
| seed = 42 | ||||||||||
|
|
||||||||||
|
|
||||||||||
| def evaluate(param_names, param_values, data: dict, k): | ||||||||||
| def evaluate(param_names: dict[int, list[Any]], param_values: Any, | ||||||||||
| data: dict[str, Any], k: int) -> list[Any]: | ||||||||||
|
Comment on lines
+24
to
+25
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧹 Nitpick | 🔵 Trivial Unused parameter The ♻️ Proposed fix def evaluate(param_names: dict[int, list[Any]], param_values: Any,
- data: dict[str, Any], k: int) -> list[Any]:
+ data: dict[str, Any], _k: int) -> list[Any]:📝 Committable suggestion
Suggested change
🧰 Tools🪛 Ruff (0.14.10)24-24: Dynamically typed expressions (typing.Any) are disallowed in (ANN401) 25-25: Unused function argument: (ARG001) 🤖 Prompt for AI Agents |
||||||||||
| m = len(param_values) | ||||||||||
|
Comment on lines
+24
to
26
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Critical:
Proposed fix (make dependencies explicit)-def evaluate(param_names: dict[int, list[Any]], param_values: Any,
- data: dict[str, Any], k: int) -> list[Any]:
+def evaluate(
+ param_names: dict[int, list[Any]],
+ param_values: Any,
+ data: dict[str, Any],
+ *,
+ config: TemoaConfig,
+ db_path: str,
+) -> list[Any]:
m = len(param_values)
for j in range(0, m):
names = param_names[j]
@@
- mdl, res = run_actions.solve_instance(instance=instance, solver_name=config.solver_name)
+ mdl, res = run_actions.solve_instance(instance=instance, solver_name=config.solver_name)
@@
- con = sqlite3.connect(db_file)
+ con = sqlite3.connect(db_path)
@@
con.close()
return morris_objectivesYou’ll also need to update the Also applies to: 40-63 🧰 Tools🪛 Ruff (0.14.10)24-24: Dynamically typed expressions (typing.Any) are disallowed in (ANN401) 25-25: Unused function argument: (ARG001) 🤖 Prompt for AI Agents |
||||||||||
| for j in range(0, m): | ||||||||||
| names = param_names[j] | ||||||||||
|
|
@@ -51,16 +51,16 @@ def evaluate(param_names, param_values, data: dict, k): | |||||||||
| cur.execute('SELECT * FROM output_objective') | ||||||||||
| output_query = cur.fetchall() | ||||||||||
| for row in output_query: | ||||||||||
| Y_OF = row[-1] | ||||||||||
| cur.execute("SELECT emis_comm, SUM(emission) FROM output_emissionn WHERE emis_comm='co2'") | ||||||||||
| y_of = row[-1] | ||||||||||
| cur.execute("SELECT emis_comm, SUM(emission) FROM output_emission WHERE emis_comm='co2'") | ||||||||||
| output_query = cur.fetchall() | ||||||||||
| for row in output_query: | ||||||||||
| Y_CumulativeCO2 = row[-1] | ||||||||||
| Morris_Objectives = [] | ||||||||||
| Morris_Objectives.append(Y_OF) | ||||||||||
| Morris_Objectives.append(Y_CumulativeCO2) | ||||||||||
| y_cumulative_co2 = row[-1] | ||||||||||
| morris_objectives = [] | ||||||||||
| morris_objectives.append(y_of) | ||||||||||
| morris_objectives.append(y_cumulative_co2) | ||||||||||
| con.close() | ||||||||||
| return Morris_Objectives | ||||||||||
| return morris_objectives | ||||||||||
|
|
||||||||||
|
|
||||||||||
| morris_root = Path(__file__).parent | ||||||||||
|
|
@@ -137,80 +137,76 @@ def evaluate(param_names, param_values, data: dict, k): | |||||||||
| file.write('\n') | ||||||||||
|
|
||||||||||
| # load a data portal, retrieve the data dict for the problem | ||||||||||
| config = TemoaConfig.build_config(config_file=config_path, output_path='.') | ||||||||||
| config = TemoaConfig.build_config(config_file=config_path, output_path=Path('.')) | ||||||||||
| loader = HybridLoader(db_connection=con, config=config) | ||||||||||
| loader.load_data_portal() | ||||||||||
| data = loader.data | ||||||||||
|
|
||||||||||
| problem = read_param_file(str(param_file), delimiter=' ') | ||||||||||
| param_values = sample( | ||||||||||
| problem, N=10, num_levels=8, optimal_trajectories=False, local_optimization=False, seed=seed | ||||||||||
| ) | ||||||||||
| print(param_values) | ||||||||||
| print(param_names) | ||||||||||
| n = len(param_values) | ||||||||||
| # pull the data | ||||||||||
|
|
||||||||||
| num_cores = 1 # multiprocessing.cpu_count() | ||||||||||
| Morris_Objectives = Parallel(n_jobs=num_cores)( | ||||||||||
| delayed(evaluate)(param_names, param_values[i, :], data, i) for i in range(0, n) | ||||||||||
| ) | ||||||||||
| Morris_Objectives = array(Morris_Objectives) | ||||||||||
| print(Morris_Objectives) | ||||||||||
| Si_OF = morris.analyze( | ||||||||||
| problem, | ||||||||||
| param_values, | ||||||||||
| Morris_Objectives[:, 0], | ||||||||||
| conf_level=0.95, | ||||||||||
| print_to_console=False, | ||||||||||
| num_levels=8, | ||||||||||
| num_resamples=1000, | ||||||||||
| seed=seed + 1, | ||||||||||
| ) | ||||||||||
|
|
||||||||||
| Si_CumulativeCO2 = morris.analyze( | ||||||||||
| problem, | ||||||||||
| param_values, | ||||||||||
| Morris_Objectives[:, 1], | ||||||||||
| conf_level=0.95, | ||||||||||
| print_to_console=False, | ||||||||||
| num_levels=8, | ||||||||||
| num_resamples=1000, | ||||||||||
| seed=seed + 2, | ||||||||||
| ) | ||||||||||
| num_vars = problem['num_vars'] | ||||||||||
| groups, unique_group_names = compute_groups_matrix(problem['groups']) | ||||||||||
| number_of_groups = len(unique_group_names) | ||||||||||
| print( | ||||||||||
| '{:<30} {:>10} {:>10} {:>15} {:>10}'.format( | ||||||||||
| 'Parameter', 'Mu_Star', 'Mu', 'Mu_Star_Conf', 'Sigma' | ||||||||||
| ) | ||||||||||
| ) | ||||||||||
| for j in list(range(number_of_groups)): | ||||||||||
| print( | ||||||||||
| '{:30} {:10.3f} {:10.3f} {:15.3f} {:10.3f}'.format( | ||||||||||
| Si_OF['names'][j], | ||||||||||
| Si_OF['mu_star'][j], | ||||||||||
| Si_OF['mu'][j], | ||||||||||
| Si_OF['mu_star_conf'][j], | ||||||||||
| Si_OF['sigma'][j], | ||||||||||
| problem = read_param_file(str(param_file), delimiter=' ') | ||||||||||
| param_values = sample( | ||||||||||
| problem, N=10, num_levels=8, optimal_trajectories=False, local_optimization=False, seed=seed | ||||||||||
| ) | ||||||||||
| print(param_values) | ||||||||||
| print(param_names) | ||||||||||
| n = len(param_values) | ||||||||||
|
|
||||||||||
| # pull the data | ||||||||||
| num_cores = 1 # multiprocessing.cpu_count() | ||||||||||
| Morris_Objectives = Parallel(n_jobs=num_cores)( | ||||||||||
| delayed(evaluate)(param_names, param_values[i, :], data, i) for i in range(0, n) | ||||||||||
| ) | ||||||||||
| ) | ||||||||||
| import csv | ||||||||||
|
|
||||||||||
| line1 = Si_OF['mu_star'] | ||||||||||
| line2 = Si_OF['mu_star_conf'] | ||||||||||
| line3 = Si_CumulativeCO2['mu_star'] | ||||||||||
| line4 = Si_CumulativeCO2['mu_star_conf'] | ||||||||||
| with open('MMResults.csv', 'w') as f: | ||||||||||
| writer = csv.writer(f, delimiter=',') | ||||||||||
| writer.writerow(unique_group_names) | ||||||||||
| writer.writerow('Objective Function') | ||||||||||
| writer.writerow(line1) | ||||||||||
| writer.writerow(line2) | ||||||||||
| writer.writerow('Cumulative CO2 Emissions') | ||||||||||
| writer.writerow(line3) | ||||||||||
| writer.writerow(line4) | ||||||||||
|
|
||||||||||
| f.close | ||||||||||
| print('--- %s seconds ---' % (time.time() - start_time)) | ||||||||||
| Morris_Objectives = array(Morris_Objectives) | ||||||||||
|
Comment on lines
+155
to
+159
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧹 Nitpick | 🔵 Trivial Inconsistent variable naming: Other variables in this file were renamed to snake_case ( Proposed fix- Morris_Objectives = Parallel(n_jobs=num_cores)(
+ morris_objectives = Parallel(n_jobs=num_cores)(
delayed(evaluate)(param_names, param_values[i, :], data, i) for i in range(0, n)
)
- Morris_Objectives = array(Morris_Objectives)
- print(Morris_Objectives)
+ morris_objectives = array(morris_objectives)
+ print(morris_objectives)
si_of = morris.analyze(
problem,
param_values,
- Morris_Objectives[:, 0],
+ morris_objectives[:, 0],🧰 Tools🪛 Ruff (0.15.6)[warning] 156-156: Unnecessary Remove (PIE808) 🤖 Prompt for AI Agents |
||||||||||
| print(Morris_Objectives) | ||||||||||
| si_of = morris.analyze( | ||||||||||
| problem, | ||||||||||
| param_values, | ||||||||||
| Morris_Objectives[:, 0], | ||||||||||
| conf_level=0.95, | ||||||||||
| print_to_console=False, | ||||||||||
| num_levels=8, | ||||||||||
| num_resamples=1000, | ||||||||||
| seed=seed + 1, | ||||||||||
| ) | ||||||||||
|
|
||||||||||
| si_cumulative_co2 = morris.analyze( | ||||||||||
| problem, | ||||||||||
| param_values, | ||||||||||
| Morris_Objectives[:, 1], | ||||||||||
| conf_level=0.95, | ||||||||||
| print_to_console=False, | ||||||||||
| num_levels=8, | ||||||||||
| num_resamples=1000, | ||||||||||
| seed=seed + 2, | ||||||||||
| ) | ||||||||||
| groups, unique_group_names = compute_groups_matrix(problem['groups']) | ||||||||||
| number_of_groups = len(unique_group_names) | ||||||||||
| print( | ||||||||||
| '{:<30} {:>10} {:>10} {:>15} {:>10}'.format( | ||||||||||
| 'Parameter', 'Mu_Star', 'Mu', 'Mu_Star_Conf', 'Sigma' | ||||||||||
| ) | ||||||||||
| ) | ||||||||||
| for j in list(range(number_of_groups)): | ||||||||||
| print( | ||||||||||
| '{:30} {:10.3f} {:10.3f} {:15.3f} {:10.3f}'.format( | ||||||||||
| si_of['names'][j], | ||||||||||
| si_of['mu_star'][j], | ||||||||||
| si_of['mu'][j], | ||||||||||
| si_of['mu_star_conf'][j], | ||||||||||
| si_of['sigma'][j], | ||||||||||
| ) | ||||||||||
| ) | ||||||||||
|
|
||||||||||
| line1 = si_of['mu_star'] | ||||||||||
| line2 = si_of['mu_star_conf'] | ||||||||||
| line3 = si_cumulative_co2['mu_star'] | ||||||||||
| line4 = si_cumulative_co2['mu_star_conf'] | ||||||||||
| with open('MMResults.csv', 'w') as f: | ||||||||||
| writer = csv.writer(f, delimiter=',') | ||||||||||
| writer.writerow(unique_group_names) | ||||||||||
| writer.writerow('Objective Function') | ||||||||||
| writer.writerow(line1) | ||||||||||
| writer.writerow(line2) | ||||||||||
| writer.writerow('Cumulative CO2 Emissions') | ||||||||||
| writer.writerow(line3) | ||||||||||
| writer.writerow(line4) | ||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Tighten the mypy exclude regex to match
breakevenboth with and without trailing slash.Current regex only matches
^temoa/extensions/breakeven/…paths.Proposed tweak
🤖 Prompt for AI Agents