|
1 | 1 | """Pydantic models, exceptions and default values."""
|
2 |
| - |
3 | 2 | import datetime
|
| 3 | +import json |
4 | 4 | from collections import ChainMap
|
| 5 | +from contextlib import suppress |
5 | 6 | from copy import deepcopy
|
6 | 7 | from hashlib import sha512
|
7 | 8 | from os import urandom
|
8 | 9 | from pathlib import Path
|
9 |
| -from typing import Any, ChainMap as t_ChainMap, Sequence, Tuple, Union |
| 10 | +from typing import ( |
| 11 | + Any, |
| 12 | + Callable, |
| 13 | + ChainMap as t_ChainMap, |
| 14 | + Dict, |
| 15 | + Iterable, |
| 16 | + List, |
| 17 | + Optional, |
| 18 | + Sequence, |
| 19 | + Tuple, |
| 20 | + Union, |
| 21 | +) |
10 | 22 |
|
11 |
| -from pydantic import BaseModel, Extra, StrictBool, validator |
| 23 | +from jinja2 import UndefinedError |
| 24 | +from jinja2.sandbox import SandboxedEnvironment |
| 25 | +from prompt_toolkit.lexers import PygmentsLexer |
| 26 | +from prompt_toolkit.validation import Validator |
| 27 | +from pydantic import BaseModel, Extra, Field, StrictBool, validator |
| 28 | +from pygments.lexers.data import JsonLexer, YamlLexer |
| 29 | +from PyInquirer.prompt import prompt |
12 | 30 |
|
| 31 | +from ..tools import cast_answer_type, force_str_end, parse_yaml_string |
13 | 32 | from ..types import AnyByStrDict, OptStr, PathSeq, StrOrPathSeq, StrSeq
|
14 | 33 |
|
15 | 34 | # Default list of files in the template to exclude from the rendered project
|
|
31 | 50 |
|
32 | 51 | DEFAULT_TEMPLATES_SUFFIX = ".tmpl"
|
33 | 52 |
|
| 53 | +TYPE_COPIER2PYTHON: Dict[str, Callable] = { |
| 54 | + "bool": bool, |
| 55 | + "float": float, |
| 56 | + "int": int, |
| 57 | + "json": json.loads, |
| 58 | + "str": str, |
| 59 | + "yaml": parse_yaml_string, |
| 60 | +} |
| 61 | + |
34 | 62 |
|
35 | 63 | class UserMessageError(Exception):
|
36 | 64 | """Exit the program giving a message to the user."""
|
37 | 65 |
|
38 |
| - pass |
39 |
| - |
40 | 66 |
|
41 | 67 | class NoSrcPathError(UserMessageError):
|
42 | 68 | pass
|
@@ -152,3 +178,218 @@ def data(self) -> t_ChainMap[str, Any]:
|
152 | 178 | class Config:
|
153 | 179 | allow_mutation = False
|
154 | 180 | anystr_strip_whitespace = True
|
| 181 | + |
| 182 | + |
| 183 | +class Question(BaseModel): |
| 184 | + choices: Union[Dict[Any, Any], List[Any]] = Field(default_factory=list) |
| 185 | + default: Any = None |
| 186 | + help_text: str = "" |
| 187 | + multiline: Optional[bool] = None |
| 188 | + placeholder: str = "" |
| 189 | + questionary: "Questionary" |
| 190 | + secret: bool = False |
| 191 | + type_name: str = "" |
| 192 | + var_name: str |
| 193 | + when: Union[str, bool] = True |
| 194 | + |
| 195 | + class Config: |
| 196 | + arbitrary_types_allowed = True |
| 197 | + |
| 198 | + def __init__(self, **kwargs): |
| 199 | + # Transform arguments that are named like python keywords |
| 200 | + to_rename = (("help", "help_text"), ("type", "type_name")) |
| 201 | + for from_, to in to_rename: |
| 202 | + with suppress(KeyError): |
| 203 | + kwargs.setdefault(to, kwargs.pop(from_)) |
| 204 | + # Infer type from default if missing |
| 205 | + super().__init__(**kwargs) |
| 206 | + self.questionary.questions.append(self) |
| 207 | + |
| 208 | + def __repr__(self): |
| 209 | + return f"Question({self.var_name})" |
| 210 | + |
| 211 | + @validator("var_name") |
| 212 | + def _check_var_name(cls, v): |
| 213 | + if v in DEFAULT_DATA: |
| 214 | + raise ValueError("Invalid question name") |
| 215 | + return v |
| 216 | + |
| 217 | + @validator("type_name", always=True) |
| 218 | + def _check_type_name(cls, v, values): |
| 219 | + if v == "": |
| 220 | + default_type_name = type(values.get("default")).__name__ |
| 221 | + v = default_type_name if default_type_name in TYPE_COPIER2PYTHON else "yaml" |
| 222 | + if v not in TYPE_COPIER2PYTHON: |
| 223 | + raise ValueError("Invalid question type") |
| 224 | + return v |
| 225 | + |
| 226 | + def _iter_choices(self) -> Iterable[dict]: |
| 227 | + choices = self.choices |
| 228 | + if isinstance(self.choices, dict): |
| 229 | + choices = list(self.choices.items()) |
| 230 | + for choice in choices: |
| 231 | + # If a choice is a dict, it can be used raw |
| 232 | + if isinstance(choice, dict): |
| 233 | + yield choice |
| 234 | + continue |
| 235 | + # However, a choice can also be a single value... |
| 236 | + name = value = choice |
| 237 | + # ... or a value pair |
| 238 | + if isinstance(choice, (tuple, list)): |
| 239 | + name, value = choice |
| 240 | + # The name must always be a str |
| 241 | + name = str(name) |
| 242 | + yield {"name": name, "value": value} |
| 243 | + |
| 244 | + def get_default(self, autocast: bool) -> Any: |
| 245 | + try: |
| 246 | + result = self.questionary.answers_forced.get( |
| 247 | + self.var_name, self.questionary.answers_last[self.var_name] |
| 248 | + ) |
| 249 | + except KeyError: |
| 250 | + result = self.render_value(self.default) |
| 251 | + result = cast_answer_type(result, self.get_type_fn()) |
| 252 | + if not autocast: |
| 253 | + return result |
| 254 | + if self.type_name == "bool": |
| 255 | + return bool(result) |
| 256 | + if result is None: |
| 257 | + return "" |
| 258 | + return str(result) |
| 259 | + |
| 260 | + def get_choices(self) -> List[AnyByStrDict]: |
| 261 | + result = [] |
| 262 | + for choice in self._iter_choices(): |
| 263 | + formatted_choice = { |
| 264 | + key: self.render_value(value) for key, value in choice.items() |
| 265 | + } |
| 266 | + result.append(formatted_choice) |
| 267 | + return result |
| 268 | + |
| 269 | + def get_filter(self, answer) -> Any: |
| 270 | + if answer == self.get_default(autocast=True): |
| 271 | + return self.get_default() |
| 272 | + return cast_answer_type(answer, self.get_type_fn()) |
| 273 | + |
| 274 | + def get_message(self) -> str: |
| 275 | + message = "" |
| 276 | + if self.help_text: |
| 277 | + rendered_help = self.render_value(self.help_text) |
| 278 | + message = force_str_end(rendered_help) |
| 279 | + message += f"{self.var_name}? Format: {self.type_name}" |
| 280 | + return message |
| 281 | + |
| 282 | + def get_placeholder(self) -> str: |
| 283 | + return self.render_value(self.placeholder) |
| 284 | + |
| 285 | + def get_pyinquirer_structure(self): |
| 286 | + lexer = None |
| 287 | + result = { |
| 288 | + "default": self.get_default(autocast=True), |
| 289 | + "filter": self.get_filter, |
| 290 | + "message": self.get_message(), |
| 291 | + "mouse_support": True, |
| 292 | + "name": self.var_name, |
| 293 | + "qmark": "🕵️" if self.secret else "🎤", |
| 294 | + "validator": Validator.from_callable(self.get_validator), |
| 295 | + "when": self.get_when, |
| 296 | + } |
| 297 | + multiline = self.multiline |
| 298 | + pyinquirer_type = "input" |
| 299 | + if self.type_name == "bool": |
| 300 | + pyinquirer_type = "confirm" |
| 301 | + if self.choices: |
| 302 | + pyinquirer_type = "list" |
| 303 | + result["choices"] = self.get_choices() |
| 304 | + if pyinquirer_type == "input": |
| 305 | + if self.secret: |
| 306 | + pyinquirer_type = "password" |
| 307 | + elif self.type_name == "yaml": |
| 308 | + lexer = PygmentsLexer(YamlLexer) |
| 309 | + elif self.type_name == "json": |
| 310 | + lexer = PygmentsLexer(JsonLexer) |
| 311 | + placeholder = self.get_placeholder() |
| 312 | + if placeholder: |
| 313 | + result["placeholder"] = placeholder |
| 314 | + multiline = multiline or ( |
| 315 | + multiline is None and self.type_name in {"yaml", "json"} |
| 316 | + ) |
| 317 | + result.update({"type": pyinquirer_type, "lexer": lexer, "multiline": multiline}) |
| 318 | + return result |
| 319 | + |
| 320 | + def get_type_fn(self) -> Callable: |
| 321 | + return TYPE_COPIER2PYTHON.get(self.type_name, parse_yaml_string) |
| 322 | + |
| 323 | + def get_validator(self, document) -> bool: |
| 324 | + type_fn = self.get_type_fn() |
| 325 | + try: |
| 326 | + type_fn(document) |
| 327 | + return True |
| 328 | + except Exception: |
| 329 | + return False |
| 330 | + |
| 331 | + def get_when(self, answers) -> bool: |
| 332 | + if ( |
| 333 | + # Skip on --force |
| 334 | + not self.questionary.ask_user |
| 335 | + # Skip on --data=this_question=some_answer |
| 336 | + or self.var_name in self.questionary.answers_forced |
| 337 | + ): |
| 338 | + return False |
| 339 | + when = self.when |
| 340 | + when = self.render_value(when) |
| 341 | + when = cast_answer_type(when, parse_yaml_string) |
| 342 | + return bool(when) |
| 343 | + |
| 344 | + def render_value(self, value: Any) -> str: |
| 345 | + """Render a single templated value using Jinja. |
| 346 | +
|
| 347 | + If the value cannot be used as a template, it will be returned as is. |
| 348 | + """ |
| 349 | + try: |
| 350 | + template = self.questionary.env.from_string(value) |
| 351 | + except TypeError: |
| 352 | + # value was not a string |
| 353 | + return value |
| 354 | + try: |
| 355 | + return template.render(**self.questionary.get_best_answers()) |
| 356 | + except UndefinedError as error: |
| 357 | + raise UserMessageError(str(error)) from error |
| 358 | + |
| 359 | + |
| 360 | +class Questionary(BaseModel): |
| 361 | + answers_forced: AnyByStrDict = Field(default_factory=dict) |
| 362 | + answers_last: AnyByStrDict = Field(default_factory=dict) |
| 363 | + answers_user: AnyByStrDict = Field(default_factory=dict) |
| 364 | + ask_user: bool = True |
| 365 | + env: SandboxedEnvironment |
| 366 | + questions: List[Question] = Field(default_factory=list) |
| 367 | + |
| 368 | + class Config: |
| 369 | + arbitrary_types_allowed = True |
| 370 | + |
| 371 | + def __init__(self, **kwargs): |
| 372 | + super().__init__(**kwargs) |
| 373 | + |
| 374 | + def get_best_answers(self) -> t_ChainMap[str, Any]: |
| 375 | + return ChainMap(self.answers_user, self.answers_last, self.answers_forced) |
| 376 | + |
| 377 | + def get_answers(self) -> AnyByStrDict: |
| 378 | + if self.ask_user: |
| 379 | + prompt( |
| 380 | + (question.get_pyinquirer_structure() for question in self.questions), |
| 381 | + answers=self.answers_user, |
| 382 | + raise_keyboard_interrupt=True, |
| 383 | + ) |
| 384 | + else: |
| 385 | + # Avoid prompting to not requiring a TTy when --force |
| 386 | + self.answers_user.update( |
| 387 | + { |
| 388 | + question.var_name: question.get_default(autocast=False) |
| 389 | + for question in self.questions |
| 390 | + } |
| 391 | + ) |
| 392 | + return self.answers_user |
| 393 | + |
| 394 | + |
| 395 | +Question.update_forward_refs() |
0 commit comments