Skip to content

Commit 5fe0683

Browse files
author
Jairo Llopis
committed
Use PyInquirer
1 parent de95399 commit 5fe0683

File tree

7 files changed

+356
-223
lines changed

7 files changed

+356
-223
lines changed

.pre-commit-config.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,12 @@ repos:
1111
# hooks running from local virtual environment
1212
- repo: local
1313
hooks:
14+
- id: autoflake
15+
name: autoflake
16+
entry: poetry run autoflake
17+
language: system
18+
types: [python]
19+
args: ["-i", "--remove-all-unused-imports", "--ignore-init-module-imports"]
1420
- id: black
1521
name: black
1622
entry: poetry run black

copier/config/objects.py

Lines changed: 246 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,34 @@
11
"""Pydantic models, exceptions and default values."""
2-
32
import datetime
3+
import json
44
from collections import ChainMap
5+
from contextlib import suppress
56
from copy import deepcopy
67
from hashlib import sha512
78
from os import urandom
89
from pathlib import Path
9-
from typing import Any, ChainMap as t_ChainMap, Sequence, Tuple, Union
10+
from typing import (
11+
Any,
12+
Callable,
13+
ChainMap as t_ChainMap,
14+
Dict,
15+
Iterable,
16+
List,
17+
Optional,
18+
Sequence,
19+
Tuple,
20+
Union,
21+
)
1022

11-
from pydantic import BaseModel, Extra, StrictBool, validator
23+
from jinja2 import UndefinedError
24+
from jinja2.sandbox import SandboxedEnvironment
25+
from prompt_toolkit.lexers import PygmentsLexer
26+
from prompt_toolkit.validation import Validator
27+
from pydantic import BaseModel, Extra, Field, StrictBool, validator
28+
from pygments.lexers.data import JsonLexer, YamlLexer
29+
from PyInquirer.prompt import prompt
1230

31+
from ..tools import cast_answer_type, force_str_end, parse_yaml_string
1332
from ..types import AnyByStrDict, OptStr, PathSeq, StrOrPathSeq, StrSeq
1433

1534
# Default list of files in the template to exclude from the rendered project
@@ -31,12 +50,19 @@
3150

3251
DEFAULT_TEMPLATES_SUFFIX = ".tmpl"
3352

53+
TYPE_COPIER2PYTHON: Dict[str, Callable] = {
54+
"bool": bool,
55+
"float": float,
56+
"int": int,
57+
"json": json.loads,
58+
"str": str,
59+
"yaml": parse_yaml_string,
60+
}
61+
3462

3563
class UserMessageError(Exception):
3664
"""Exit the program giving a message to the user."""
3765

38-
pass
39-
4066

4167
class NoSrcPathError(UserMessageError):
4268
pass
@@ -152,3 +178,218 @@ def data(self) -> t_ChainMap[str, Any]:
152178
class Config:
153179
allow_mutation = False
154180
anystr_strip_whitespace = True
181+
182+
183+
class Question(BaseModel):
184+
choices: Union[Dict[Any, Any], List[Any]] = Field(default_factory=list)
185+
default: Any = None
186+
help_text: str = ""
187+
multiline: Optional[bool] = None
188+
placeholder: str = ""
189+
questionary: "Questionary"
190+
secret: bool = False
191+
type_name: str = ""
192+
var_name: str
193+
when: Union[str, bool] = True
194+
195+
class Config:
196+
arbitrary_types_allowed = True
197+
198+
def __init__(self, **kwargs):
199+
# Transform arguments that are named like python keywords
200+
to_rename = (("help", "help_text"), ("type", "type_name"))
201+
for from_, to in to_rename:
202+
with suppress(KeyError):
203+
kwargs.setdefault(to, kwargs.pop(from_))
204+
# Infer type from default if missing
205+
super().__init__(**kwargs)
206+
self.questionary.questions.append(self)
207+
208+
def __repr__(self):
209+
return f"Question({self.var_name})"
210+
211+
@validator("var_name")
212+
def _check_var_name(cls, v):
213+
if v in DEFAULT_DATA:
214+
raise ValueError("Invalid question name")
215+
return v
216+
217+
@validator("type_name", always=True)
218+
def _check_type_name(cls, v, values):
219+
if v == "":
220+
default_type_name = type(values.get("default")).__name__
221+
v = default_type_name if default_type_name in TYPE_COPIER2PYTHON else "yaml"
222+
if v not in TYPE_COPIER2PYTHON:
223+
raise ValueError("Invalid question type")
224+
return v
225+
226+
def _iter_choices(self) -> Iterable[dict]:
227+
choices = self.choices
228+
if isinstance(self.choices, dict):
229+
choices = list(self.choices.items())
230+
for choice in choices:
231+
# If a choice is a dict, it can be used raw
232+
if isinstance(choice, dict):
233+
yield choice
234+
continue
235+
# However, a choice can also be a single value...
236+
name = value = choice
237+
# ... or a value pair
238+
if isinstance(choice, (tuple, list)):
239+
name, value = choice
240+
# The name must always be a str
241+
name = str(name)
242+
yield {"name": name, "value": value}
243+
244+
def get_default(self, autocast: bool) -> Any:
245+
try:
246+
result = self.questionary.answers_forced.get(
247+
self.var_name, self.questionary.answers_last[self.var_name]
248+
)
249+
except KeyError:
250+
result = self.render_value(self.default)
251+
result = cast_answer_type(result, self.get_type_fn())
252+
if not autocast:
253+
return result
254+
if self.type_name == "bool":
255+
return bool(result)
256+
if result is None:
257+
return ""
258+
return str(result)
259+
260+
def get_choices(self) -> List[AnyByStrDict]:
261+
result = []
262+
for choice in self._iter_choices():
263+
formatted_choice = {
264+
key: self.render_value(value) for key, value in choice.items()
265+
}
266+
result.append(formatted_choice)
267+
return result
268+
269+
def get_filter(self, answer) -> Any:
270+
if answer == self.get_default(autocast=True):
271+
return self.get_default()
272+
return cast_answer_type(answer, self.get_type_fn())
273+
274+
def get_message(self) -> str:
275+
message = ""
276+
if self.help_text:
277+
rendered_help = self.render_value(self.help_text)
278+
message = force_str_end(rendered_help)
279+
message += f"{self.var_name}? Format: {self.type_name}"
280+
return message
281+
282+
def get_placeholder(self) -> str:
283+
return self.render_value(self.placeholder)
284+
285+
def get_pyinquirer_structure(self):
286+
lexer = None
287+
result = {
288+
"default": self.get_default(autocast=True),
289+
"filter": self.get_filter,
290+
"message": self.get_message(),
291+
"mouse_support": True,
292+
"name": self.var_name,
293+
"qmark": "🕵️" if self.secret else "🎤",
294+
"validator": Validator.from_callable(self.get_validator),
295+
"when": self.get_when,
296+
}
297+
multiline = self.multiline
298+
pyinquirer_type = "input"
299+
if self.type_name == "bool":
300+
pyinquirer_type = "confirm"
301+
if self.choices:
302+
pyinquirer_type = "list"
303+
result["choices"] = self.get_choices()
304+
if pyinquirer_type == "input":
305+
if self.secret:
306+
pyinquirer_type = "password"
307+
elif self.type_name == "yaml":
308+
lexer = PygmentsLexer(YamlLexer)
309+
elif self.type_name == "json":
310+
lexer = PygmentsLexer(JsonLexer)
311+
placeholder = self.get_placeholder()
312+
if placeholder:
313+
result["placeholder"] = placeholder
314+
multiline = multiline or (
315+
multiline is None and self.type_name in {"yaml", "json"}
316+
)
317+
result.update({"type": pyinquirer_type, "lexer": lexer, "multiline": multiline})
318+
return result
319+
320+
def get_type_fn(self) -> Callable:
321+
return TYPE_COPIER2PYTHON.get(self.type_name, parse_yaml_string)
322+
323+
def get_validator(self, document) -> bool:
324+
type_fn = self.get_type_fn()
325+
try:
326+
type_fn(document)
327+
return True
328+
except Exception:
329+
return False
330+
331+
def get_when(self, answers) -> bool:
332+
if (
333+
# Skip on --force
334+
not self.questionary.ask_user
335+
# Skip on --data=this_question=some_answer
336+
or self.var_name in self.questionary.answers_forced
337+
):
338+
return False
339+
when = self.when
340+
when = self.render_value(when)
341+
when = cast_answer_type(when, parse_yaml_string)
342+
return bool(when)
343+
344+
def render_value(self, value: Any) -> str:
345+
"""Render a single templated value using Jinja.
346+
347+
If the value cannot be used as a template, it will be returned as is.
348+
"""
349+
try:
350+
template = self.questionary.env.from_string(value)
351+
except TypeError:
352+
# value was not a string
353+
return value
354+
try:
355+
return template.render(**self.questionary.get_best_answers())
356+
except UndefinedError as error:
357+
raise UserMessageError(str(error)) from error
358+
359+
360+
class Questionary(BaseModel):
361+
answers_forced: AnyByStrDict = Field(default_factory=dict)
362+
answers_last: AnyByStrDict = Field(default_factory=dict)
363+
answers_user: AnyByStrDict = Field(default_factory=dict)
364+
ask_user: bool = True
365+
env: SandboxedEnvironment
366+
questions: List[Question] = Field(default_factory=list)
367+
368+
class Config:
369+
arbitrary_types_allowed = True
370+
371+
def __init__(self, **kwargs):
372+
super().__init__(**kwargs)
373+
374+
def get_best_answers(self) -> t_ChainMap[str, Any]:
375+
return ChainMap(self.answers_user, self.answers_last, self.answers_forced)
376+
377+
def get_answers(self) -> AnyByStrDict:
378+
if self.ask_user:
379+
prompt(
380+
(question.get_pyinquirer_structure() for question in self.questions),
381+
answers=self.answers_user,
382+
raise_keyboard_interrupt=True,
383+
)
384+
else:
385+
# Avoid prompting to not requiring a TTy when --force
386+
self.answers_user.update(
387+
{
388+
question.var_name: question.get_default(autocast=False)
389+
for question in self.questions
390+
}
391+
)
392+
return self.answers_user
393+
394+
395+
Question.update_forward_refs()

0 commit comments

Comments
 (0)