22import sys
33import json
44import shlex
5- from typing import Any
65from pathlib import Path
76from logging import Logger
87from functools import partial
8+ from typing_extensions import Required
9+ from typing import TypeAlias , TypedDict
910from dataclasses import field , dataclass
11+ from collections .abc import Sequence , MutableMapping
1012
1113import click
1214import nonestorage
5052 "bootstrap" : _ ("bootstrap (for beginner or user)" ),
5153 "simple" : _ ("simple (for plugin developer)" ),
5254}
55+ HIDDEN_FILE_OVERRIDES = {".env" , ".env.dev" , ".env.prod" , ".gitignore" , ".vscode" }
56+ SerializedJSON : TypeAlias = str
5357
5458BLACKLISTED_PROJECT_NAME .update (sys .stdlib_module_names )
5559
5660
61+ class ProjectTemplateProps (TypedDict ):
62+ """项目模板渲染变量字典集"""
63+
64+ project_name : Required [str ]
65+ inplace : bool
66+ adapters : SerializedJSON
67+ drivers : SerializedJSON
68+ environment : MutableMapping [str , str ]
69+ use_src : bool
70+ devtools : Sequence [str ]
71+
72+
5773@dataclass
5874class ProjectContext :
5975 """项目模板生成上下文
@@ -63,12 +79,14 @@ class ProjectContext:
6379 packages: 项目需要安装的包
6480 """
6581
66- variables : dict [str , Any ] = field (default_factory = dict )
82+ variables : ProjectTemplateProps = field ( # pyright: ignore[reportAssignmentType]
83+ default_factory = dict
84+ )
6785 packages : list [str ] = field (default_factory = list )
6886
6987
7088def project_name_validator (name : str ) -> bool :
71- return (
89+ return name == "." or (
7290 bool (re .match (VALID_PROJECT_NAME , name ))
7391 and name not in BLACKLISTED_PROJECT_NAME
7492 )
@@ -92,6 +110,25 @@ async def prompt_common_context(context: ProjectContext) -> ProjectContext:
92110 error_message = _ ("Invalid project name!" ),
93111 ).prompt_async (style = CLI_DEFAULT_STYLE )
94112 context .variables ["project_name" ] = project_name
113+ context .variables ["inplace" ] = False
114+
115+ if project_name == "." :
116+ _parent_dirname = Path ("." ).absolute ().name
117+ if not project_name_validator (_parent_dirname ):
118+ click .secho (_ ("Invalid project name!" ), fg = "red" )
119+ raise CancelledError
120+ if any (
121+ (f .name in HIDDEN_FILE_OVERRIDES or not f .name .startswith ("." ))
122+ for f in Path (project_name ).iterdir ()
123+ ):
124+ if not await ConfirmPrompt (
125+ _ ("Current folder is not empty. Overwrite existing files?" ),
126+ False ,
127+ ).prompt_async (style = CLI_DEFAULT_STYLE ):
128+ click .echo (_ ("Stopped creating bot." ))
129+ raise CancelledError
130+ project_name = context .variables ["project_name" ] = _parent_dirname
131+ context .variables ["inplace" ] = True
95132
96133 confirm = False
97134 adapters = []
@@ -297,7 +334,9 @@ async def create(
297334 use_venv = False
298335 project_dir_name = context .variables ["project_name" ].replace (" " , "-" )
299336 project_dir = Path (output_dir or "." ) / project_dir_name
300- venv_dir = project_dir / ".venv"
337+ venv_dir = (
338+ Path ("./.venv" ) if context .variables ["inplace" ] else project_dir / ".venv"
339+ )
301340
302341 if install_dependencies :
303342 try :
0 commit comments