1use crate::{Machine};
2use compose_error_codes::ErrorCode;
3use compose_library::diag::compose_codespan_reporting::term::termcolor::{
4 ColorChoice, StandardStream,
5};
6use compose_library::diag::{FileError, FileResult, SourceDiagnostic, SourceResult, Warned, write_diagnostics, write_diagnostics_to_string};
7use compose_library::{Library, Value, World, };
8use compose_syntax::{FileId, Source};
9use ecow::{EcoVec, eco_format, eco_vec};
10use std::collections::HashMap;
11use std::fmt::Debug;
12use std::io::{Read, Write};
13use std::sync::Mutex;
14use tap::pipe::Pipe;
15
16mod flow_variables;
17#[cfg(test)]
18mod iterators;
19#[cfg(test)]
20mod snippets;
21
22pub struct TestWorld {
23 sources: Mutex<HashMap<FileId, Source>>,
24 entrypoint: FileId,
25 library: Library,
26 stdout: Mutex<String>,
27}
28
29impl Clone for TestWorld {
30 fn clone(&self) -> Self {
31 Self {
32 sources: Mutex::new(self.sources.lock().unwrap().clone()),
33 entrypoint: self.entrypoint,
34 library: self.library.clone(),
35 stdout: Mutex::new(String::new()),
36 }
37 }
38}
39
40impl Debug for TestWorld {
41 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42 f.debug_struct("TestWorld")
45 .field("entrypoint", &self.entrypoint)
46 .pipe(|d| {
47 if let Ok(sources) = self.sources.try_lock() {
48 d.field("sources", &sources)
49 } else {
50 d.field("sources", &"<locked>")
51 }
52 })
53 .finish()
54 }
55}
56
57impl TestWorld {
58 pub fn from_str(text: &str) -> Self {
59 let entrypoint = FileId::new("main.comp");
60 let source = Source::new(entrypoint, text.to_string());
61 let mut sources = HashMap::new();
62 sources.insert(source.id(), source);
63
64 Self {
65 sources: Mutex::new(sources),
66 entrypoint,
67 library: Library::default(),
68 stdout: Mutex::new(String::new()),
69 }
70 }
71
72 pub fn new() -> Self {
73 Self::from_str("")
74 }
75
76 pub fn entrypoint_src(&self) -> Source {
77 self.sources
78 .lock()
79 .unwrap()
80 .get(&self.entrypoint)
81 .unwrap()
82 .clone()
83 }
84
85 pub fn edit_source(&self, file_id: FileId, editor: impl FnOnce(&mut Source)) {
86 let mut sources = self.sources.lock().unwrap();
87 let source = sources.get_mut(&file_id).unwrap();
88 editor(source);
89 }
90}
91
92impl World for TestWorld {
93 fn entry_point(&self) -> FileId {
94 self.entrypoint
95 }
96
97 fn source(&self, file_id: FileId) -> FileResult<Source> {
98 let sources = self
99 .sources
100 .lock()
101 .map_err(|e| FileError::Other(Some(eco_format!("{e}"))))?;
102 match sources.get(&file_id) {
103 Some(s) => Ok(s.clone()),
104 None => Err(FileError::NotFound(file_id.path().0.clone())),
105 }
106 }
107
108 fn library(&self) -> &Library {
109 &self.library
110 }
111
112 fn write(
113 &self,
114 f: &mut dyn FnMut(&mut dyn Write) -> std::io::Result<()>,
115 ) -> std::io::Result<()> {
116 let mut buffer: Vec<u8> = Vec::new();
117 f(&mut buffer)?;
118 let output = String::from_utf8(buffer).expect("Invalid UTF-8");
119 self.stdout.lock().expect("failed to lock stdout").push_str(&output);
120 Ok(())
121 }
122
123 fn read(&self, f: &mut dyn FnMut(&mut dyn Read) -> std::io::Result<()>) -> std::io::Result<()> {
124 f(&mut std::io::stdin())
125 }
126}
127
128pub fn print_diagnostics(
129 world: &TestWorld,
130 errors: &[SourceDiagnostic],
131 warnings: &[SourceDiagnostic],
132) {
133 let stdout = StandardStream::stdout(ColorChoice::Always);
134 write_diagnostics(
135 world,
136 errors,
137 warnings,
138 &mut stdout.lock(),
139 &Default::default(),
140 )
141 .expect("failed to print diagnostics");
142}
143
144#[must_use]
145pub fn eval_code_with_vm(vm: &mut Machine, world: &TestWorld, input: &str) -> TestResult {
146 if input.is_empty() {
147 return TestResult {
148 value: Ok(Value::unit()),
149 warnings: eco_vec!(),
150 world: world.clone(),
151 };
152 }
153
154 let len_before_edit = world.entrypoint_src().nodes().len();
155 world.edit_source(world.entry_point(), |s| {
156 s.append(format!("{}{input}", if !s.text().is_empty() { "\n" } else { "" }).as_str())
157 });
158
159 let source = world.entrypoint_src();
160 let len_after_edit = source.nodes().len();
161
162 let Warned { value, warnings } = crate::eval_source_range(
163 &source,
164 len_before_edit..len_after_edit,
165 vm,
166 );
167
168 TestResult {
169 value,
170 warnings,
171 world: world.clone(),
172 }
173}
174
175pub struct TestResult {
176 pub value: SourceResult<Value>,
177 pub warnings: EcoVec<SourceDiagnostic>,
178 pub world: TestWorld,
179}
180
181impl TestResult {
182 #[track_caller]
183 pub fn assert_no_errors(self) -> Self {
184 match &self.value {
185 Ok(_) => {}
186 Err(errors) => {
187 print_diagnostics(&self.world, errors, &self.warnings);
188 panic!("expected no errors, but got: {:?}", errors)
189 }
190 }
191 self
192 }
193
194 pub fn errors(&self) -> &[SourceDiagnostic] {
195 match &self.value {
196 Ok(_) => &[],
197 Err(errors) => errors,
198 }
199 }
200
201 pub fn warnings(&self) -> &[SourceDiagnostic] {
202 &self.warnings
203 }
204
205 #[track_caller]
206 pub fn assert_no_warnings(self) -> Self {
207 if !self.warnings.is_empty() {
208 print_diagnostics(&self.world, &self.warnings, &self.warnings);
209 panic!("expected no warnings, but got: {:?}", self.warnings);
210 }
211 self
212 }
213
214 #[track_caller]
215 pub fn assert_errors(self, expected_errors: &[&ErrorCode]) -> Self {
216 match &self.value {
217 Ok(_) if expected_errors.is_empty() => {}
218 Ok(_) => panic!("expected errors, but got none"),
219 Err(errors) => {
220 let missing_expected_errors = expected_errors
221 .iter()
222 .filter(|e| !errors.iter().any(|c| c.code == Some(*e)))
223 .map(|e| e.code)
224 .collect::<Vec<_>>();
225 let unexpected_errors = errors
226 .iter()
227 .filter(|d| {
228 !expected_errors
229 .iter()
230 .any(|e| Some(e.code) == d.code.map(|c| c.code))
231 })
232 .cloned()
233 .collect::<Vec<_>>();
234
235 let mut error = String::new();
236 if !missing_expected_errors.is_empty() {
237 error.push_str("expected errors did not occur: ");
238 error.push_str(&missing_expected_errors.join(", "));
239 error.push_str("\n");
240 }
241
242 if !unexpected_errors.is_empty() {
243 error.push_str("unexpected errors occurred: \n");
244 error.push_str(write_diagnostics_to_string(&self.world, &unexpected_errors, &[]).as_str());
245 }
246
247 if !missing_expected_errors.is_empty() || !unexpected_errors.is_empty() {
248 panic!("{}", error);
249 }
250 }
251 }
252
253 self
254 }
255
256 #[track_caller]
257 pub fn assert_warnings(self, expected_warnings: &[&ErrorCode]) -> Self {
258 if self
259 .warnings
260 .iter()
261 .map(|e| e.code)
262 .zip(expected_warnings.iter().map(Some))
263 .any(|(a, b)| a.as_ref() != b)
264 {
265 print_diagnostics(&self.world, &self.warnings, &self.warnings);
266 panic!(
267 "expected warnings: {:?}, but got: {:?}",
268 expected_warnings, self.warnings
269 )
270 }
271
272 self
273 }
274
275 pub fn get_value(self) -> Value {
276 self.value.expect("code failed to evaluate")
277 }
278
279 #[allow(unused)]
280 pub fn get_warnings(&self) -> EcoVec<SourceDiagnostic> {
281 self.warnings.clone()
282 }
283
284 #[allow(unused)]
285 pub fn get_errors(&self) -> EcoVec<SourceDiagnostic> {
286 match &self.value {
287 Ok(_) => eco_vec!(),
288 Err(errors) => errors.clone(),
289 }
290 }
291
292 pub fn assert_stdout(&self, expected: &str) {
293 let stdout = self.world.stdout.lock().expect("failed to lock stdout").clone();
294 assert_eq!(stdout, expected);
295 }
296
297 pub fn assert_stdout_predicate(&self, predicate: impl FnOnce(&str) -> bool) {
298 let stdout = self.world.stdout.lock().expect("failed to lock stdout").clone();
299 assert!(predicate(&stdout));
300 }
301}
302
303#[must_use]
304pub fn eval_code(code: &str) -> TestResult {
305 let world = TestWorld::from_str("");
306 let mut vm = Machine::new(&world);
307 eval_code_with_vm(&mut vm, &world, code)
308}
309
310#[track_caller]
311pub fn assert_eval(code: &str) -> Value {
312 eval_code(code)
313 .assert_no_warnings()
314 .assert_no_errors()
315 .get_value()
316}
317
318#[track_caller]
319pub fn assert_eval_with_vm(vm: &mut Machine, world: &TestWorld, code: &str) -> Value {
320 eval_code_with_vm(vm, world, code)
321 .assert_no_warnings()
322 .assert_no_errors()
323 .get_value()
324}