1use crate::{EvalConfig, Machine};
2use compose_error_codes::ErrorCode;
3use compose_library::diag::compose_codespan_reporting::term::termcolor::{
4 ColorChoice, StandardStream,
5};
6use compose_library::diag::{
7 FileError, FileResult, SourceDiagnostic, SourceResult, Warned, write_diagnostics,
8};
9use compose_library::{Library, Value, World, library};
10use compose_syntax::{FileId, Source};
11use ecow::{EcoVec, eco_format, eco_vec};
12use std::collections::HashMap;
13use std::fmt::Debug;
14use std::io::{Read, Write};
15use std::sync::Mutex;
16use tap::pipe::Pipe;
17
18#[cfg(test)]
19mod iterators;
20#[cfg(test)]
21mod snippets;
22
23pub struct TestWorld {
24 sources: Mutex<HashMap<FileId, Source>>,
25 entrypoint: FileId,
26 library: Library,
27}
28
29impl Clone for TestWorld {
30 fn clone(&self) -> Self {
31 Self {
32 sources: Mutex::new(self.sources.lock().unwrap().clone()),
33 entrypoint: self.entrypoint,
34 library: self.library.clone(),
35 }
36 }
37}
38
39impl Debug for TestWorld {
40 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41 f.debug_struct("TestWorld")
44 .field("entrypoint", &self.entrypoint)
45 .pipe(|d| {
46 if let Ok(sources) = self.sources.try_lock() {
47 d.field("sources", &sources)
48 } else {
49 d.field("sources", &"<locked>")
50 }
51 })
52 .finish()
53 }
54}
55
56impl TestWorld {
57 pub fn from_str(text: &str) -> Self {
58 let entrypoint = FileId::new("main.comp");
59 let source = Source::new(entrypoint, text.to_string());
60 let mut sources = HashMap::new();
61 sources.insert(source.id(), source);
62
63 Self {
64 sources: Mutex::new(sources),
65 entrypoint,
66 library: library(),
67 }
68 }
69
70 pub fn new() -> Self {
71 Self::from_str("")
72 }
73
74 pub fn entrypoint_src(&self) -> Source {
75 self.sources
76 .lock()
77 .unwrap()
78 .get(&self.entrypoint)
79 .unwrap()
80 .clone()
81 }
82
83 pub fn edit_source(&self, file_id: FileId, editor: impl FnOnce(&mut Source)) {
84 let mut sources = self.sources.lock().unwrap();
85 let source = sources.get_mut(&file_id).unwrap();
86 editor(source);
87 }
88}
89
90impl World for TestWorld {
91 fn entry_point(&self) -> FileId {
92 self.entrypoint
93 }
94
95 fn source(&self, file_id: FileId) -> FileResult<Source> {
96 let sources = self
97 .sources
98 .lock()
99 .map_err(|e| FileError::Other(Some(eco_format!("{e}"))))?;
100 match sources.get(&file_id) {
101 Some(s) => Ok(s.clone()),
102 None => Err(FileError::NotFound(file_id.path().0.clone())),
103 }
104 }
105
106 fn library(&self) -> &Library {
107 &self.library
108 }
109
110 fn write(&self, f: &dyn Fn(&mut dyn Write) -> std::io::Result<()>) -> std::io::Result<()> {
111 f(&mut std::io::stdout())
112 }
113
114 fn read(&self, f: &dyn Fn(&mut dyn Read) -> std::io::Result<()>) -> std::io::Result<()> {
115 f(&mut std::io::stdin())
116 }
117}
118
119fn print_diagnostics(
120 world: &TestWorld,
121 errors: &[SourceDiagnostic],
122 warnings: &[SourceDiagnostic],
123) {
124 let stdout = StandardStream::stdout(ColorChoice::Always);
125 write_diagnostics(
126 world,
127 errors,
128 warnings,
129 &mut stdout.lock(),
130 &Default::default(),
131 )
132 .expect("failed to print diagnostics");
133}
134
135#[must_use]
136pub fn eval_code_with_vm(vm: &mut Machine, world: &TestWorld, input: &str) -> TestResult {
137 if input.is_empty() {
138 return TestResult {
139 value: Ok(Value::unit()),
140 warnings: eco_vec!(),
141 world: world.clone(),
142 };
143 }
144
145 let len_before_edit = world.entrypoint_src().nodes().len();
146 world.edit_source(world.entry_point(), |s| {
147 s.append(format!("{}{input}", if !s.text().is_empty() { "\n" } else { "" }).as_str())
148 });
149
150 let source = world.entrypoint_src();
151 let len_after_edit = source.nodes().len();
152
153 let Warned { value, warnings } = crate::eval_range(
154 &source,
155 len_before_edit..len_after_edit,
156 vm,
157 &EvalConfig {
158 include_syntax_warnings: true,
159 },
160 );
161
162 TestResult {
163 value,
164 warnings,
165 world: world.clone(),
166 }
167}
168
169pub struct TestResult {
170 pub value: SourceResult<Value>,
171 pub warnings: EcoVec<SourceDiagnostic>,
172 pub world: TestWorld,
173}
174
175impl TestResult {
176 #[track_caller]
177 pub fn assert_no_errors(self) -> Self {
178 match &self.value {
179 Ok(_) => {}
180 Err(errors) => {
181 print_diagnostics(&self.world, errors, &self.warnings);
182 panic!("expected no errors, but got: {:?}", errors)
183 }
184 }
185 self
186 }
187
188 #[track_caller]
189 pub fn assert_no_warnings(self) -> Self {
190 if !self.warnings.is_empty() {
191 print_diagnostics(&self.world, &self.warnings, &self.warnings);
192 panic!("expected no warnings, but got: {:?}", self.warnings);
193 }
194 self
195 }
196
197 #[track_caller]
198 pub fn assert_errors(self, expected_errors: &[ErrorCode]) -> Self {
199 match &self.value {
200 Ok(_) => panic!("expected errors, but got none"),
201 Err(errors) => {
202 if expected_errors.is_empty() {
203 panic!("expected no errors, but got: {:?}", errors)
204 }
205 if errors
206 .iter()
207 .map(|e| e.code)
208 .zip(expected_errors.iter().map(Some))
209 .any(|(a, b)| a != b)
210 {
211 print_diagnostics(&self.world, errors, &self.warnings);
212 panic!(
213 "expected errors: {:?}, but got: {:?}",
214 expected_errors, errors
215 )
216 }
217 }
218 }
219
220 self
221 }
222
223 #[track_caller]
224 pub fn assert_warnings(self, expected_warnings: &[ErrorCode]) -> Self {
225 if self
226 .warnings
227 .iter()
228 .map(|e| e.code)
229 .zip(expected_warnings.iter().map(Some))
230 .any(|(a, b)| a != b)
231 {
232 print_diagnostics(&self.world, &self.warnings, &self.warnings);
233 panic!(
234 "expected warnings: {:?}, but got: {:?}",
235 expected_warnings, self.warnings
236 )
237 }
238
239 self
240 }
241
242 pub fn get_value(self) -> Value {
243 self.value.expect("code failed to evaluate")
244 }
245
246 #[allow(unused)]
247 pub fn get_warnings(&self) -> EcoVec<SourceDiagnostic> {
248 self.warnings.clone()
249 }
250
251 #[allow(unused)]
252 pub fn get_errors(&self) -> EcoVec<SourceDiagnostic> {
253 match &self.value {
254 Ok(_) => eco_vec!(),
255 Err(errors) => errors.clone(),
256 }
257 }
258}
259
260#[must_use]
261pub fn eval_code(code: &str) -> TestResult {
262 let world = TestWorld::from_str("");
263 let mut vm = Machine::new(&world);
264 eval_code_with_vm(&mut vm, &world, code)
265}
266
267#[track_caller]
268pub fn assert_eval(code: &str) -> Value {
269 eval_code(code)
270 .assert_no_warnings()
271 .assert_no_errors()
272 .get_value()
273}
274
275#[track_caller]
276pub fn assert_eval_with_vm(vm: &mut Machine, world: &TestWorld, code: &str) -> Value {
277 eval_code_with_vm(vm, world, code)
278 .assert_no_warnings()
279 .assert_no_errors()
280 .get_value()
281}