// compose_eval/test/mod.rs

1use crate::{EvalConfig, Machine};
2use compose_error_codes::ErrorCode;
3use compose_library::diag::compose_codespan_reporting::term::termcolor::{
4    ColorChoice, StandardStream,
5};
6use compose_library::diag::{
7    FileError, FileResult, SourceDiagnostic, SourceResult, Warned, write_diagnostics,
8};
9use compose_library::{Library, Value, World, library};
10use compose_syntax::{FileId, Source};
11use ecow::{EcoVec, eco_format, eco_vec};
12use std::collections::HashMap;
13use std::fmt::Debug;
14use std::io::{Read, Write};
15use std::sync::Mutex;
16use tap::pipe::Pipe;
17
18#[cfg(test)]
19mod iterators;
20#[cfg(test)]
21mod snippets;
22
23pub struct TestWorld {
24    sources: Mutex<HashMap<FileId, Source>>,
25    entrypoint: FileId,
26    library: Library,
27}
28
29impl Clone for TestWorld {
30    fn clone(&self) -> Self {
31        Self {
32            sources: Mutex::new(self.sources.lock().unwrap().clone()),
33            entrypoint: self.entrypoint,
34            library: self.library.clone(),
35        }
36    }
37}
38
39impl Debug for TestWorld {
40    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41        // try to get the mutex lock
42
43        f.debug_struct("TestWorld")
44            .field("entrypoint", &self.entrypoint)
45            .pipe(|d| {
46                if let Ok(sources) = self.sources.try_lock() {
47                    d.field("sources", &sources)
48                } else {
49                    d.field("sources", &"<locked>")
50                }
51            })
52            .finish()
53    }
54}
55
56impl TestWorld {
57    pub fn from_str(text: &str) -> Self {
58        let entrypoint = FileId::new("main.comp");
59        let source = Source::new(entrypoint, text.to_string());
60        let mut sources = HashMap::new();
61        sources.insert(source.id(), source);
62
63        Self {
64            sources: Mutex::new(sources),
65            entrypoint,
66            library: library(),
67        }
68    }
69
70    pub fn new() -> Self {
71        Self::from_str("")
72    }
73
74    pub fn entrypoint_src(&self) -> Source {
75        self.sources
76            .lock()
77            .unwrap()
78            .get(&self.entrypoint)
79            .unwrap()
80            .clone()
81    }
82
83    pub fn edit_source(&self, file_id: FileId, editor: impl FnOnce(&mut Source)) {
84        let mut sources = self.sources.lock().unwrap();
85        let source = sources.get_mut(&file_id).unwrap();
86        editor(source);
87    }
88}
89
90impl World for TestWorld {
91    fn entry_point(&self) -> FileId {
92        self.entrypoint
93    }
94
95    fn source(&self, file_id: FileId) -> FileResult<Source> {
96        let sources = self
97            .sources
98            .lock()
99            .map_err(|e| FileError::Other(Some(eco_format!("{e}"))))?;
100        match sources.get(&file_id) {
101            Some(s) => Ok(s.clone()),
102            None => Err(FileError::NotFound(file_id.path().0.clone())),
103        }
104    }
105
106    fn library(&self) -> &Library {
107        &self.library
108    }
109
110    fn write(&self, f: &dyn Fn(&mut dyn Write) -> std::io::Result<()>) -> std::io::Result<()> {
111        f(&mut std::io::stdout())
112    }
113
114    fn read(&self, f: &dyn Fn(&mut dyn Read) -> std::io::Result<()>) -> std::io::Result<()> {
115        f(&mut std::io::stdin())
116    }
117}
118
119fn print_diagnostics(
120    world: &TestWorld,
121    errors: &[SourceDiagnostic],
122    warnings: &[SourceDiagnostic],
123) {
124    let stdout = StandardStream::stdout(ColorChoice::Always);
125    write_diagnostics(
126        world,
127        errors,
128        warnings,
129        &mut stdout.lock(),
130        &Default::default(),
131    )
132    .expect("failed to print diagnostics");
133}
134
135#[must_use]
136pub fn eval_code_with_vm(vm: &mut Machine, world: &TestWorld, input: &str) -> TestResult {
137    if input.is_empty() {
138        return TestResult {
139            value: Ok(Value::unit()),
140            warnings: eco_vec!(),
141            world: world.clone(),
142        };
143    }
144
145    let len_before_edit = world.entrypoint_src().nodes().len();
146    world.edit_source(world.entry_point(), |s| {
147        s.append(format!("{}{input}", if !s.text().is_empty() { "\n" } else { "" }).as_str())
148    });
149
150    let source = world.entrypoint_src();
151    let len_after_edit = source.nodes().len();
152
153    let Warned { value, warnings } = crate::eval_range(
154        &source,
155        len_before_edit..len_after_edit,
156        vm,
157        &EvalConfig {
158            include_syntax_warnings: true,
159        },
160    );
161
162    TestResult {
163        value,
164        warnings,
165        world: world.clone(),
166    }
167}
168
169pub struct TestResult {
170    pub value: SourceResult<Value>,
171    pub warnings: EcoVec<SourceDiagnostic>,
172    pub world: TestWorld,
173}
174
175impl TestResult {
176    #[track_caller]
177    pub fn assert_no_errors(self) -> Self {
178        match &self.value {
179            Ok(_) => {}
180            Err(errors) => {
181                print_diagnostics(&self.world, errors, &self.warnings);
182                panic!("expected no errors, but got: {:?}", errors)
183            }
184        }
185        self
186    }
187
188    #[track_caller]
189    pub fn assert_no_warnings(self) -> Self {
190        if !self.warnings.is_empty() {
191            print_diagnostics(&self.world, &self.warnings, &self.warnings);
192            panic!("expected no warnings, but got: {:?}", self.warnings);
193        }
194        self
195    }
196
197    #[track_caller]
198    pub fn assert_errors(self, expected_errors: &[ErrorCode]) -> Self {
199        match &self.value {
200            Ok(_) => panic!("expected errors, but got none"),
201            Err(errors) => {
202                if expected_errors.is_empty() {
203                    panic!("expected no errors, but got: {:?}", errors)
204                }
205                if errors
206                    .iter()
207                    .map(|e| e.code)
208                    .zip(expected_errors.iter().map(Some))
209                    .any(|(a, b)| a != b)
210                {
211                    print_diagnostics(&self.world, errors, &self.warnings);
212                    panic!(
213                        "expected errors: {:?}, but got: {:?}",
214                        expected_errors, errors
215                    )
216                }
217            }
218        }
219
220        self
221    }
222
223    #[track_caller]
224    pub fn assert_warnings(self, expected_warnings: &[ErrorCode]) -> Self {
225        if self
226            .warnings
227            .iter()
228            .map(|e| e.code)
229            .zip(expected_warnings.iter().map(Some))
230            .any(|(a, b)| a != b)
231        {
232            print_diagnostics(&self.world, &self.warnings, &self.warnings);
233            panic!(
234                "expected warnings: {:?}, but got: {:?}",
235                expected_warnings, self.warnings
236            )
237        }
238
239        self
240    }
241
242    pub fn get_value(self) -> Value {
243        self.value.expect("code failed to evaluate")
244    }
245
246    #[allow(unused)]
247    pub fn get_warnings(&self) -> EcoVec<SourceDiagnostic> {
248        self.warnings.clone()
249    }
250
251    #[allow(unused)]
252    pub fn get_errors(&self) -> EcoVec<SourceDiagnostic> {
253        match &self.value {
254            Ok(_) => eco_vec!(),
255            Err(errors) => errors.clone(),
256        }
257    }
258}
259
260#[must_use]
261pub fn eval_code(code: &str) -> TestResult {
262    let world = TestWorld::from_str("");
263    let mut vm = Machine::new(&world);
264    eval_code_with_vm(&mut vm, &world, code)
265}
266
267#[track_caller]
268pub fn assert_eval(code: &str) -> Value {
269    eval_code(code)
270        .assert_no_warnings()
271        .assert_no_errors()
272        .get_value()
273}
274
275#[track_caller]
276pub fn assert_eval_with_vm(vm: &mut Machine, world: &TestWorld, code: &str) -> Value {
277    eval_code_with_vm(vm, world, code)
278        .assert_no_warnings()
279        .assert_no_errors()
280        .get_value()
281}