commit 17cd97be119bcf20d8b0c8fac91bfaec2fc5213c Author: szdytom Date: Sat Mar 28 23:13:37 2026 +0800 add initial implementation of simple-ink interactive story framework with YAML support and command line interface Signed-off-by: szdytom diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ff6402d --- /dev/null +++ b/.gitignore @@ -0,0 +1,7 @@ +.vscode +__pycache__ +*.pyc +*.pyo +*.pyd +*.db +saves diff --git a/README.md b/README.md new file mode 100644 index 0000000..536b5b3 --- /dev/null +++ b/README.md @@ -0,0 +1,86 @@ +# simple-ink + +一个用 Python 编写的命令行交互式小说框架: +- 从 YAML 配置加载剧情图 +- 节点文本展示 + 选项分支 +- 条件过滤与状态变化(简化 DSL) +- 支持存档/读档 + +## 快速开始 + +1. 安装依赖: + +```bash +pip install -r requirements.txt +``` + +2. 运行示例故事: + +```bash +python main.py --story data/main_story.yaml +``` + +## 配置结构 + +```yaml +story_id: demo_story +start: intro +nodes: + - id: intro + text: 你的开场文本 + effects: + - hp = 10 + options: + - text: 前进 + target: next_node + condition: hp > 0 + effects: + - hp -= 1 +``` + +字段说明: +- `story_id`: 故事标识 +- `start`: 起始节点 ID +- `nodes`: 节点列表 +- `node.id`: 节点唯一 ID +- `node.text`: 显示文本 +- `node.options`: 选项列表 +- `node.effects`: 进入该节点时执行的状态变化 +- `node.end`: `true` 表示终局 +- `option.text`: 选项文本 +- `option.target`: 目标节点 ID +- `option.condition`: 选项显示条件 +- `option.effects`: 选择该选项后执行的状态变化 + +## DSL 语法 + +### condition +支持: +- 比较:`== != > >= < <=` +- 逻辑:`and or not` +- 变量名直接引用:`has_key`(不存在变量视为 `False`) +- 例子: + - `hp > 0 and not is_cursed` + - `coins >= 1` + - `"lamp" in flags` + +### effect +支持: +- 赋值:`hp = 10`、`is_cursed = false` +- 数值增减:`hp += 2`、`coins -= 1` +- 标记集合(`flags`):`flags += lamp`、`flags -= lamp` + +## 交互命令 + +游戏运行时可输入: +- `:help` 显示命令 +- `:state` 显示当前状态 +- `:save [路径]` 保存进度(默认 `saves/latest.json`) +- `:load [路径]` 读取进度(默认 `saves/latest.json`) +- `:quit` 退出 + +## 测试 + +```bash +pytest -q +``` diff --git a/data/main_story.yaml b/data/main_story.yaml new file mode 100644 index 0000000..c107967 --- /dev/null +++ b/data/main_story.yaml @@ -0,0 +1,52 @@ +story_id: demo_story +start: intro +nodes: + - id: intro + text: | + 你在雨夜醒来,口袋里只有一枚旧硬币。 + 前方有一扇铁门和一条小巷。 + effects: + - coins = 1 + options: + - text: 推开铁门 + target: gate + - text: 走进小巷 + target: alley + + - id: alley + text: 你在小巷尽头发现一盏油灯。 + options: + - text: 拿起油灯 + target: gate + effects: + - flags += lamp + - text: 原路返回 + target: intro + + - id: gate + text: 铁门上有投币孔,门旁贴着模糊告示。 + options: + - text: 投入硬币开门 + target: archive + condition: coins >= 1 + effects: + - coins -= 1 + - text: 用油灯照亮告示 + target: clue + condition: "'lamp' in flags" + - text: 放弃并离开 + target: bad_end + + - id: clue + text: 告示写着:真正的门在你身后。 + options: + - text: 回头寻找暗门 + target: archive + + - id: archive + text: 你进入档案室,找到了真相。 + end: true + + - id: bad_end + text: 你转身离开,故事在雨中结束。 + end: true diff --git a/main.py b/main.py new file mode 100644 index 0000000..ab1ce94 --- /dev/null +++ b/main.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +import argparse +from pathlib import Path + +from src.engine import InteractiveRunner, StoryEngine +from src.parser import StoryValidationError, parse_story + + +def build_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="simple-ink: 命令行交互式小说框架") + parser.add_argument( + "--story", + type=Path, + default=Path("data/main_story.yaml"), + help="故事配置文件路径(YAML)", + ) + parser.add_argument( + "--load", + type=Path, + default=None, + help="启动时立即加载存档路径", + ) + return parser + + +def main() -> int: + args = build_arg_parser().parse_args() + + try: + parsed = parse_story(args.story) + except StoryValidationError as exc: + print(f"故事配置错误: {exc}") + return 2 + + for warning in parsed.warnings: + print(f"[warning] {warning}") + + engine = StoryEngine(parsed.story) + if args.load is not None: + try: + engine.load(args.load) + except Exception as exc: + print(f"加载启动存档失败: {exc}") + return 3 + + runner = InteractiveRunner(engine) + runner.run() + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..0fdbcb0 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,2 @@ +PyYAML>=6.0 +pytest>=8.0 diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..926283c --- /dev/null +++ b/src/__init__.py @@ -0,0 +1 @@ +"""simple-ink package.""" diff --git a/src/cli.py b/src/cli.py new file mode 100644 index 0000000..1122708 --- /dev/null +++ b/src/cli.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +from src.models import Node, Option + + +def render_node(node: Node, visible_options: list[Option]) -> None: + print("\n" + "=" * 60) + print(node.text) + print("=" * 60) + + if not visible_options: + print("\n[没有可选项,故事结束。]") + return + + print("\n可选项:") + for idx, option in enumerate(visible_options, start=1): + print(f" {idx}. {option.text}") + + +def render_help() -> None: + print("\n命令:") + print(" :help 显示帮助") + print(" :state 显示当前状态") + print(" :save [路径] 保存进度,默认 saves/latest.json") + print(" :load [路径] 读取进度,默认 saves/latest.json") + print(" :quit 退出游戏") + + +def read_input() -> str: + return input("\n请输入选项编号或命令: ").strip() diff --git a/src/dsl.py b/src/dsl.py new file mode 100644 index 0000000..bf8d273 --- /dev/null +++ b/src/dsl.py @@ -0,0 +1,175 @@ +from __future__ import annotations + +import ast +import re +from typing import Any + +from src.models import State + + +class DslError(ValueError): + """Raised when a condition/effect expression is invalid.""" + + +_ALLOWED_BOOL_OPS = (ast.And, ast.Or) +_ALLOWED_CMP_OPS = ( + ast.Eq, + ast.NotEq, + ast.Gt, + ast.GtE, + ast.Lt, + ast.LtE, + ast.In, + ast.NotIn, +) + + +def _safe_eval_expr(node: ast.AST, state: State) -> Any: + if isinstance(node, ast.Expression): + return _safe_eval_expr(node.body, state) + + if isinstance(node, ast.BoolOp): + if not isinstance(node.op, _ALLOWED_BOOL_OPS): + raise DslError("Only 'and' and 'or' are allowed in conditions") + values = [_safe_eval_expr(v, state) for v in node.values] + return all(values) if isinstance(node.op, ast.And) else any(values) + + if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.Not): + return not _safe_eval_expr(node.operand, state) + + if isinstance(node, ast.Compare): + left = _safe_eval_expr(node.left, state) + for op, right_node in zip(node.ops, node.comparators): + if not isinstance(op, _ALLOWED_CMP_OPS): + raise DslError("Unsupported comparator in condition") + right = _safe_eval_expr(right_node, state) + if isinstance(op, ast.Eq): + ok = left == right + elif isinstance(op, ast.NotEq): + ok = left != right + elif isinstance(op, ast.Gt): + try: + ok = left > right + except TypeError as exc: + raise DslError("Type mismatch in '>' comparison") from exc + elif isinstance(op, ast.GtE): + try: + ok = left >= right + except TypeError as exc: + raise DslError("Type mismatch in '>=' comparison") from exc + elif isinstance(op, ast.Lt): + try: + ok = left < right + except TypeError as exc: + raise DslError("Type mismatch in '<' comparison") from exc + elif isinstance(op, ast.In): + try: + ok = left in right + except TypeError: + ok = False + elif isinstance(op, ast.NotIn): + try: + ok = left not in right + except TypeError: + ok = True + else: + try: + ok = left <= right + except TypeError as exc: + raise DslError("Type mismatch in '<=' comparison") from exc + if not ok: + return False + left = right + return True + + if isinstance(node, ast.Name): + return state.get(node.id, 0) + + if isinstance(node, ast.Constant): + return node.value + + raise DslError("Unsupported syntax in condition") + + +def evaluate_condition(condition: str | None, state: State) -> bool: + if not condition: + return True + try: + parsed = ast.parse(condition, mode="eval") + except SyntaxError as exc: + raise DslError(f"Invalid condition syntax: {condition}") from exc + return bool(_safe_eval_expr(parsed, state)) + + +def _parse_literal(value: str) -> Any: + lowered = value.strip().lower() + if lowered == "true": + return True + if lowered == "false": + return False + + if re.fullmatch(r"-?\d+", value.strip()): + return int(value.strip()) + + if re.fullmatch(r"-?\d+\.\d+", value.strip()): + return float(value.strip()) + + if (value.startswith('"') and value.endswith('"')) or ( + value.startswith("'") and value.endswith("'") + ): + return value[1:-1] + + return value + + +def _ensure_flags(state: State) -> set[str]: + flags = state.get("flags") + if flags is None: + flags = set() + state["flags"] = flags + if isinstance(flags, set): + return flags + if isinstance(flags, list): + converted = set(str(x) for x in flags) + state["flags"] = converted + return converted + raise DslError("state['flags'] must be a set or list") + + +def apply_effect(effect: str, state: State) -> None: + line = effect.strip() + if not line: + return + + m = re.fullmatch(r"([A-Za-z_][A-Za-z0-9_]*)\s*(\+=|-=|=)\s*(.+)", line) + if m: + name, op, raw_value = m.groups() + value = _parse_literal(raw_value) + + if op == "=": + state[name] = value + return + + if name == "flags": + flags = _ensure_flags(state) + token = str(value) + if op == "+=": + flags.add(token) + else: + flags.discard(token) + return + + current = state.get(name, 0) + if not isinstance(current, (int, float)) or not isinstance(value, (int, float)): + raise DslError( + f"Only numeric variables support '{op}' (except flags), got {name}" + ) + state[name] = current + value if op == "+=" else current - value + return + + raise DslError(f"Invalid effect syntax: {effect}") + + +def apply_effects(effects: list[str], state: State) -> None: + for effect in effects: + apply_effect(effect, state) diff --git a/src/engine.py b/src/engine.py new file mode 100644 index 0000000..70d23bf --- /dev/null +++ b/src/engine.py @@ -0,0 +1,147 @@ +from __future__ import annotations + +from pathlib import Path + +from src import cli +from src.dsl import DslError, apply_effects, evaluate_condition +from src.models import Node, Option, State, Story +from src.storage import load_game, save_game + +DEFAULT_SAVE_PATH = Path("saves/latest.json") + + +class StoryEngine: + def __init__(self, story: Story, initial_state: State | None = None) -> None: + self.story = story + self.current_node_id = story.start + self.state: State = dict(initial_state or {}) + self._pending_node_effects = True + + def get_current_node(self) -> Node: + self._apply_pending_node_effects() + return self.story.nodes[self.current_node_id] + + def get_visible_options(self) -> list[Option]: + node = self.get_current_node() + return [ + option + for option in node.options + if evaluate_condition(option.condition, self.state) + ] + + def is_finished(self) -> bool: + node = self.get_current_node() + return node.end or not self.get_visible_options() + + def choose(self, visible_index: int) -> None: + visible_options = self.get_visible_options() + if visible_index < 0 or visible_index >= len(visible_options): + raise IndexError("Option index out of range") + + option = visible_options[visible_index] + apply_effects(option.effects, self.state) + self.current_node_id = option.target + self._pending_node_effects = True + + def save(self, save_path: str | Path = DEFAULT_SAVE_PATH) -> Path: + path = Path(save_path) + save_game(path, self.story.story_id, self.current_node_id, self.state) + return path + + def load(self, save_path: str | Path = DEFAULT_SAVE_PATH) -> Path: + path = Path(save_path) + data = load_game(path) + + if data.story_id != self.story.story_id: + raise ValueError( + f"Save story_id '{data.story_id}' does not match '{self.story.story_id}'" + ) + if data.current_node not in self.story.nodes: + raise ValueError( + f"Save current node '{data.current_node}' no longer exists in story" + ) + + self.current_node_id = data.current_node + self.state = data.state + self._pending_node_effects = False + return path + + def _apply_pending_node_effects(self) -> None: + if not self._pending_node_effects: + return + node = self.story.nodes[self.current_node_id] + apply_effects(node.effects, self.state) + self._pending_node_effects = False + + +class InteractiveRunner: + def __init__(self, engine: StoryEngine) -> None: + self.engine = engine + + def run(self) -> None: + print("欢迎来到 simple-ink。输入 :help 查看命令。") + while True: + node = self.engine.get_current_node() + options = self.engine.get_visible_options() + cli.render_node(node, options) + + if node.end or not options: + print("\n[游戏结束]\n") + return + + raw = cli.read_input() + if not raw: + continue + + if raw.startswith(":"): + if self._handle_command(raw): + return + continue + + try: + index = int(raw) - 1 + self.engine.choose(index) + except ValueError: + print("请输入有效编号或命令。") + except IndexError: + print("选项编号超出范围。") + except DslError as exc: + print(f"执行效果失败: {exc}") + + def _handle_command(self, raw: str) -> bool: + parts = raw.split(maxsplit=1) + command = parts[0] + argument = parts[1] if len(parts) > 1 else None + + if command == ":help": + cli.render_help() + return False + + if command == ":state": + print(self.engine.state) + return False + + if command == ":save": + path = Path(argument) if argument else DEFAULT_SAVE_PATH + try: + resolved = self.engine.save(path) + print(f"已保存到: {resolved}") + except Exception as exc: # pragma: no cover - defensive + print(f"保存失败: {exc}") + return False + + if command == ":load": + path = Path(argument) if argument else DEFAULT_SAVE_PATH + try: + resolved = self.engine.load(path) + print(f"已加载: {resolved}") + except Exception as exc: + print(f"加载失败: {exc}") + return False + + if command == ":quit": + print("已退出。") + return True + + print("未知命令,输入 :help 查看可用命令。") + return False diff --git a/src/models.py b/src/models.py new file mode 100644 index 0000000..51dfbea --- /dev/null +++ b/src/models.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + + +@dataclass(slots=True) +class Option: + text: str + target: str + condition: str | None = None + effects: list[str] = field(default_factory=list) + + +@dataclass(slots=True) +class Node: + id: str + text: str + options: list[Option] = field(default_factory=list) + effects: list[str] = field(default_factory=list) + end: bool = False + + +@dataclass(slots=True) +class Story: + story_id: str + start: str + nodes: dict[str, Node] + + +State = dict[str, Any] diff --git a/src/parser.py b/src/parser.py new file mode 100644 index 0000000..4cccbb2 --- /dev/null +++ b/src/parser.py @@ -0,0 +1,169 @@ +from __future__ import annotations + +from pathlib import Path +from typing import Any + +import yaml + +from src.dsl import DslError, evaluate_condition, apply_effect +from src.models import Node, Option, Story + + +class StoryValidationError(ValueError): + """Raised when a story configuration is invalid.""" + + +class ParsedStory: + def __init__(self, story: Story, warnings: list[str] | None = None) -> None: + self.story = story + self.warnings = warnings or [] + + +def _to_effects(raw: Any) -> list[str]: + if raw is None: + return [] + if isinstance(raw, str): + return [raw] + if isinstance(raw, list) and all(isinstance(x, str) for x in raw): + return raw + raise StoryValidationError("'effect(s)' must be a string or list of strings") + + +def parse_story(file_path: str | Path) -> ParsedStory: + path = Path(file_path) + if not path.exists(): + raise StoryValidationError(f"Story file not found: {path}") + + with path.open("r", encoding="utf-8") as f: + data = yaml.safe_load(f) + + if not isinstance(data, dict): + raise StoryValidationError("Story root must be an object") + + story_id = data.get("story_id", path.stem) + start = data.get("start") + node_items = data.get("nodes") + + if not isinstance(story_id, str) or not story_id: + raise StoryValidationError("'story_id' must be a non-empty string") + if not isinstance(start, str) or not start: + raise StoryValidationError("'start' must be a non-empty string") + if not isinstance(node_items, list) or not node_items: + raise StoryValidationError("'nodes' must be a non-empty list") + + nodes: dict[str, Node] = {} + + for item in node_items: + if not isinstance(item, dict): + raise StoryValidationError("Each node must be an object") + + node_id = item.get("id") + text = item.get("text") + end = bool(item.get("end", False)) + node_effects = _to_effects(item.get("effects") if "effects" in item else item.get("effect")) + + if not isinstance(node_id, str) or not node_id: + raise StoryValidationError("Each node requires a non-empty string 'id'") + if node_id in nodes: + raise StoryValidationError(f"Duplicate node id: {node_id}") + if not isinstance(text, str): + raise StoryValidationError(f"Node '{node_id}' must have string 'text'") + + options_raw = item.get("options", []) + if not isinstance(options_raw, list): + raise StoryValidationError(f"Node '{node_id}' field 'options' must be a list") + + options: list[Option] = [] + for opt in options_raw: + if not isinstance(opt, dict): + raise StoryValidationError(f"Node '{node_id}' has non-object option") + opt_text = opt.get("text") + target = opt.get("target") + condition = opt.get("condition") + opt_effects = _to_effects(opt.get("effects") if "effects" in opt else opt.get("effect")) + + if not isinstance(opt_text, str) or not opt_text: + raise StoryValidationError(f"Node '{node_id}' has option with invalid 'text'") + if not isinstance(target, str) or not target: + raise StoryValidationError(f"Node '{node_id}' option '{opt_text}' has invalid 'target'") + if condition is not None and not isinstance(condition, str): + raise StoryValidationError( + f"Node '{node_id}' option '{opt_text}' has non-string 'condition'" + ) + + options.append( + Option( + text=opt_text, + target=target, + condition=condition, + effects=opt_effects, + ) + ) + + for effect in node_effects: + try: + apply_effect(effect, {}) + except DslError as exc: + raise StoryValidationError( + f"Node '{node_id}' has invalid effect '{effect}': {exc}" + ) from exc + + for option in options: + if option.condition: + try: + evaluate_condition(option.condition, {}) + except DslError as exc: + raise StoryValidationError( + f"Node '{node_id}' option '{option.text}' has invalid condition " + f"'{option.condition}': {exc}" + ) from exc + for effect in option.effects: + try: + apply_effect(effect, {}) + except DslError as exc: + raise StoryValidationError( + f"Node '{node_id}' option '{option.text}' has invalid effect " + f"'{effect}': {exc}" + ) from exc + + nodes[node_id] = Node( + id=node_id, + text=text, + options=options, + effects=node_effects, + end=end, + ) + + if start not in nodes: + raise StoryValidationError(f"'start' node '{start}' does not exist") + + for node in nodes.values(): + if not node.end and not node.options: + raise StoryValidationError( + f"Node '{node.id}' is not an end node and has no options" + ) + for option in node.options: + if option.target not in nodes: + raise StoryValidationError( + f"Node '{node.id}' option '{option.text}' targets missing node '{option.target}'" + ) + + warnings: list[str] = [] + reachable = _reachable_nodes(start, nodes) + for node_id in nodes: + if node_id not in reachable: + warnings.append(f"Node '{node_id}' is unreachable from start '{start}'") + + return ParsedStory(story=Story(story_id=story_id, start=start, nodes=nodes), warnings=warnings) + + +def _reachable_nodes(start: str, nodes: dict[str, Node]) -> set[str]: + visited: set[str] = set() + stack = [start] + while stack: + node_id = stack.pop() + if node_id in visited: + continue + visited.add(node_id) + stack.extend(opt.target for opt in nodes[node_id].options) + return visited diff --git a/src/storage.py b/src/storage.py new file mode 100644 index 0000000..a3f46e2 --- /dev/null +++ b/src/storage.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +import json +from dataclasses import asdict, dataclass +from datetime import datetime, timezone +from pathlib import Path +from typing import Any + +from src.models import State + +SAVE_VERSION = 1 + + +@dataclass(slots=True) +class SaveData: + version: int + story_id: str + current_node: str + state: State + timestamp: str + + +def _serialize_value(value: Any) -> Any: + if isinstance(value, set): + return sorted(value) + if isinstance(value, dict): + return {k: _serialize_value(v) for k, v in value.items()} + if isinstance(value, list): + return [_serialize_value(v) for v in value] + return value + + +def _deserialize_state(state: dict[str, Any]) -> State: + result: State = dict(state) + flags = result.get("flags") + if isinstance(flags, list): + result["flags"] = set(str(x) for x in flags) + return result + + +def save_game( + file_path: str | Path, + story_id: str, + current_node: str, + state: State, +) -> None: + payload = SaveData( + version=SAVE_VERSION, + story_id=story_id, + current_node=current_node, + state=_serialize_value(state), + timestamp=datetime.now(timezone.utc).isoformat(), + ) + + path = Path(file_path) + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w", encoding="utf-8") as f: + json.dump(asdict(payload), f, ensure_ascii=False, indent=2) + + +def load_game(file_path: str | Path) -> SaveData: + path = Path(file_path) + if not path.exists(): + raise FileNotFoundError(f"Save file not found: {path}") + + with path.open("r", encoding="utf-8") as f: + payload = json.load(f) + + required = {"version", "story_id", "current_node", "state", "timestamp"} + if not isinstance(payload, dict) or not required.issubset(payload.keys()): + raise ValueError("Invalid save file format") + + if payload["version"] != SAVE_VERSION: + raise ValueError( + f"Unsupported save version: {payload['version']} (expected {SAVE_VERSION})" + ) + + if not isinstance(payload["state"], dict): + raise ValueError("Save 'state' must be an object") + + return SaveData( + version=payload["version"], + story_id=payload["story_id"], + current_node=payload["current_node"], + state=_deserialize_state(payload["state"]), + timestamp=payload["timestamp"], + ) diff --git a/tests/test_engine.py b/tests/test_engine.py new file mode 100644 index 0000000..1362feb --- /dev/null +++ b/tests/test_engine.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +from src.engine import StoryEngine +from src.parser import parse_story + + +def test_branch_and_effects() -> None: + parsed = parse_story("data/main_story.yaml") + engine = StoryEngine(parsed.story) + + assert engine.current_node_id == "intro" + + intro_options = engine.get_visible_options() + assert len(intro_options) == 2 + + # Choose alley, then pick lamp. + engine.choose(1) + assert engine.current_node_id == "alley" + engine.choose(0) + assert engine.current_node_id == "gate" + + options = engine.get_visible_options() + option_texts = [o.text for o in options] + assert "用油灯照亮告示" in option_texts + + +def test_condition_hides_option() -> None: + parsed = parse_story("data/main_story.yaml") + engine = StoryEngine(parsed.story) + + # intro -> gate without taking lamp + engine.choose(0) + options = engine.get_visible_options() + option_texts = [o.text for o in options] + assert "用油灯照亮告示" not in option_texts diff --git a/tests/test_parser.py b/tests/test_parser.py new file mode 100644 index 0000000..b8edce2 --- /dev/null +++ b/tests/test_parser.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +from pathlib import Path + +import pytest + +from src.parser import StoryValidationError, parse_story + + +def test_parse_story_success() -> None: + parsed = parse_story(Path("data/main_story.yaml")) + assert parsed.story.story_id == "demo_story" + assert parsed.story.start == "intro" + assert "gate" in parsed.story.nodes + + +def test_parse_story_missing_target_fails(tmp_path: Path) -> None: + bad_story = tmp_path / "bad.yaml" + bad_story.write_text( + """ +story_id: bad +start: a +nodes: + - id: a + text: hello + options: + - text: go + target: missing +""".strip(), + encoding="utf-8", + ) + + with pytest.raises(StoryValidationError): + parse_story(bad_story) + + +def test_parse_story_invalid_effect_fails(tmp_path: Path) -> None: + bad_story = tmp_path / "bad_effect.yaml" + bad_story.write_text( + """ +story_id: bad_effect +start: a +nodes: + - id: a + text: hello + effects: + - hp **= 2 + end: true +""".strip(), + encoding="utf-8", + ) + + with pytest.raises(StoryValidationError): + parse_story(bad_story) diff --git a/tests/test_storage.py b/tests/test_storage.py new file mode 100644 index 0000000..ca0ff2a --- /dev/null +++ b/tests/test_storage.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +from pathlib import Path + +from src.engine import StoryEngine +from src.parser import parse_story +from src.storage import load_game + + +def test_save_and_load_roundtrip(tmp_path: Path) -> None: + parsed = parse_story("data/main_story.yaml") + engine = StoryEngine(parsed.story) + + # intro -> gate + engine.choose(0) + assert engine.current_node_id == "gate" + + save_path = tmp_path / "save.json" + engine.save(save_path) + + payload = load_game(save_path) + assert payload.story_id == parsed.story.story_id + assert payload.current_node == "gate" + + engine2 = StoryEngine(parsed.story) + engine2.load(save_path) + assert engine2.current_node_id == "gate" + assert engine2.state == engine.state