1use crate::vm::{FlowEvent, TrackedContainer};
2use crate::{Eval, Evaluated, Machine, ValueEvaluatedExtensions};
3use compose_library::diag::{IntoSourceDiagnostic, SourceResult, Spanned, bail, error};
4use compose_library::{
5 Args, Binding, BindingKind, Closure, Func, IntoValue, Library, Scope, Scopes, Value,
6 VariableAccessError, Visibility,
7};
8use compose_syntax::ast::{AstNode, Expr, Ident, Param, ParamKind};
9use compose_syntax::{Label, Span, SyntaxNode, ast};
10use ecow::{EcoString, EcoVec};
11use std::collections::HashMap;
12
13impl Eval for ast::Lambda<'_> {
14 fn eval(self, vm: &mut Machine) -> SourceResult<Evaluated> {
15 let guard = vm.temp_root_guard();
16
17 let mut defaults = Vec::new();
18 for param in self.params().children() {
19 if let ast::ParamKind::Named(named) = param.kind() {
20 defaults.push(named.expr().eval(guard.vm)?.value);
21 }
22 }
23
24 let captured = {
25 let mut errors = EcoVec::new();
26 let mut scope = Scope::new();
27 for capture in self.captures().children() {
28 let span = capture.binding().span();
29 let name = capture.binding().get();
30 let binding = match guard.vm.get(&capture.binding()).cloned() {
31 Ok(v) => v,
32 Err(e) => {
33 let VariableAccessError::Unbound(unbound) = e else {
34 errors.push(e.into_source_diagnostic(span));
35 continue;
36 };
37
38 let mut err = error!(
39 span, "unknown variable `{name}` in closure capture list";
40 label_message: "variable `{name}` is not defined in the outer scope and cannot be captured";
41 );
42 if let Some(msg) = unbound.misspellings_hint_message() {
43 err.hint(msg);
44 }
45 errors.push(err);
46
47 continue;
48 }
49 };
50
51 let value = match validate_capture(capture, &binding, guard.vm) {
52 Ok(v) => v,
53 Err(e) => {
54 errors.extend(e.into_iter());
55 continue;
56 }
57 };
58
59 scope.bind(
60 name.clone(),
61 Binding::new(value.clone(), span).with_kind(match capture.is_mut() {
62 true => BindingKind::Mutable,
63 false => BindingKind::Immutable { first_assign: None },
64 }),
65 );
66 }
67
68 if !errors.is_empty() {
69 return Err(errors);
70 }
71
72 scope
73 };
74
75 let unresolved_captures = {
76 let mut visitor = CapturesVisitor::new(
77 &guard.vm.frames.top.scopes,
78 Some(guard.vm.engine.world.library()),
79 &captured,
80 );
81 visitor.visit_lambda(self);
82 visitor.finish()
83 };
84
85 let closure = Closure {
86 name: None,
87 node: self.to_untyped().clone(),
88 defaults,
89 num_pos_params: self
90 .params()
91 .children()
92 .filter(|p| matches!(p.kind(), ast::ParamKind::Pos(_)))
93 .count(),
94 captured,
95 unresolved_captures,
96 };
97
98 if !guard.vm.context.closure_capture.should_defer() {
99 closure.resolve_captures()?
100 }
101
102 Ok(Func::from(closure)
103 .into_value()
104 .spanned(self.span())
105 .mutable())
106 }
107}
108
109fn validate_capture<'a>(
110 capture: ast::Capture,
111 binding: &'a Binding,
112 vm: &mut Machine,
113) -> SourceResult<&'a Value> {
114 let span = capture.binding().span();
115 let name = capture.binding().get();
116 if capture.is_mut() && capture.is_ref() && !binding.kind().is_mut() {
117 bail!(
118 capture.span(), "cannot capture variable `{name}` as `ref mut` because it is not declared as mutable";
119 label_message: "capture is declared as a mutable reference";
120 label: Label::secondary(binding.span(), "was defined as immutable here");
121 note: "captured mutable references must match the mutability of the original declaration";
122 hint: "declare the variable as mutable: `let mut {name} = ...`";
123 hint: "or remove `mut` from the capture: `|ref {name}, ...|";
124 );
125 }
126
127 let value = binding.read_checked(span, vm.sink_mut());
128
129 if capture.is_ref() && !value.is_box() {
130 bail!(
131 span, "cannot capture non reference type by reference";
132 label_message: "this captures by reference";
133 note: "only boxed values can be captured by reference"
134 );
135 }
136
137 Ok(value)
138}
139
140fn define(
141 vm: &mut Machine,
142 ident: Ident,
143 Spanned { value, span }: Spanned<Value>,
144 param: Param,
145) -> SourceResult<()> {
146 if param.is_ref() && !value.is_box() {
147 bail!(
148 ident.span(),
149 "cannot take a reference to a value type. consider boxing it first with `box::new(value)`"
150 );
151 }
152 if !param.is_ref() && value.is_box() {
153 bail!(span, "Cannot bind a box value to a non-reference parameter";
154 label_message: "this parameter is a boxed value";
155 label: Label::primary(ident.span(), "this parameter is declared as an owned value");
156 note: "passing a reference type (like `box`) into an owned parameter is not allowed";
157 hint: "if the parameter intends to modify the original value, mark the parameter as `ref mut`";
158 hint: "if you want to pass the value by ownership, use `.clone_inner()` to create a new copy";
159 );
160 }
161
162 let kind = if param.is_mut() {
163 BindingKind::ParamMut
164 } else {
165 BindingKind::Param
166 };
167 vm.define(ident, value, kind, Visibility::Private)?;
168
169 Ok(())
170}
171
172pub fn eval_lambda(closure: &Closure, vm: &mut Machine, args: Args) -> SourceResult<Value> {
174 let guard = vm.temp_root_guard();
175 let ast_closure = closure
176 .node
177 .cast::<ast::Lambda>()
178 .expect("closure is not an ast closure");
179 let params = ast_closure.params();
180 let statements = ast_closure.statements();
181
182 guard.vm.track_tmp_root(&args);
184
185 let result = guard
186 .vm
187 .with_frame(move |vm| {
188 let mut args = args;
189 if let Some(Spanned { value, span }) = &closure.name {
190 vm.try_bind(
191 value.clone(),
192 Binding::new(Func::from(closure.clone()), *span),
193 )?;
194 }
195
196 for (k, v) in closure.captured.bindings() {
197 vm.bind(k.clone(), v.clone());
198 }
199
200 let mut defaults = closure.defaults.iter();
201 for p in params.children() {
202 match p.kind() {
203 ast::ParamKind::Pos(pattern) => match pattern {
204 ast::Pattern::Single(Expr::Ident(ident)) => {
205 define(vm, ident, args.expect(&ident)?, p)?;
206 }
207 pattern => bail!(pattern.span(), "Patterns not supported in closures yet"),
208 },
209 ast::ParamKind::Named(named) => {
210 let name = named.name();
211 let default = defaults.next().unwrap();
212 let value = args
213 .named(&name)?
214 .unwrap_or_else(|| Spanned::new(default.clone(), named.expr().span()));
215 define(vm, name, value, p)?;
216 }
217 }
218 }
219
220 args.finish()?;
222
223 let mut output = Value::unit();
224 for statement in statements {
225 output = statement.eval(vm)?.value;
226 match &vm.flow {
227 None => {}
228 Some(FlowEvent::Return(_, Some(explicit))) => {
229 let explicit = explicit.clone();
230 vm.flow = None;
231 return Ok(explicit);
232 }
233 Some(FlowEvent::Return(_, None)) => {
234 vm.flow = None;
235 return Ok(Value::unit());
236 }
237 Some(other) => bail!(other.forbidden()),
238 }
239 }
240
241 SourceResult::Ok(output)
242 })
243 .track_tmp_root(guard.vm);
244
245 guard.vm.maybe_gc();
246
247 result
248}
249
250#[derive(Debug)]
252pub struct CapturesVisitor<'a> {
253 external: &'a Scopes<'a>,
255 internal: Scopes<'a>,
257 captures: HashMap<EcoString, Span>,
259}
260
261impl<'a> CapturesVisitor<'a> {
262 pub fn new(external: &'a Scopes<'a>, library: Option<&'a Library>, existing: &Scope) -> Self {
263 let mut inst = Self {
264 external,
265 internal: Scopes::new(library),
266 captures: HashMap::new(),
267 };
268
269 for (k, v) in existing.bindings() {
270 inst.internal.top.bind(k.clone(), v.clone());
271 }
272
273 inst
274 }
275 pub(crate) fn visit_lambda(&mut self, closure: ast::Lambda<'a>) {
276 for param in closure.params().children() {
277 match param.kind() {
278 ParamKind::Pos(pat) => {
279 for ident in pat.bindings() {
280 self.bind(ident);
281 }
282 }
283 ParamKind::Named(named) => {
284 self.bind(named.name());
285 }
286 }
287 }
288
289 for capture in closure.captures().children() {
290 self.visit(capture.to_untyped());
291 }
292
293 for statement in closure.statements() {
294 self.visit(statement.to_untyped());
295 }
296 }
297
298 pub fn visit(&mut self, node: &'a SyntaxNode) {
299 if let Some(ast::Statement::Let(let_binding)) = node.cast() {
300 if let Some(init) = let_binding.initial_value() {
301 self.visit(init.to_untyped())
302 }
303
304 for ident in let_binding.pattern().bindings() {
305 self.bind(ident);
306 }
307 return;
308 }
309
310 let expr = match node.cast::<Expr>() {
311 Some(expr) => expr,
312 None => {
313 if let Some(named) = node.cast::<ast::Named>() {
314 self.visit(named.expr().to_untyped());
316 return;
317 }
318
319 Expr::default()
320 }
321 };
322
323 match expr {
324 Expr::Ident(ident) => self.capture(ident),
325 Expr::CodeBlock(_) => {
326 self.internal.enter();
327 for child in node.children() {
328 self.visit(child);
329 }
330 self.internal.exit();
331 }
332 Expr::FieldAccess(access) => {
333 self.visit(access.target().to_untyped());
334 }
335 Expr::Lambda(closure) => {
336 for param in closure.params().children() {
337 if let ast::ParamKind::Named(named) = param.kind() {
338 self.visit(named.expr().to_untyped());
339 }
340 }
341
342 for capture in closure.captures().children() {
343 self.visit(capture.to_untyped());
344 }
345
346 }
351
352 Expr::ForLoop(for_loop) => {
353 self.visit(for_loop.iterable().to_untyped());
355
356 self.internal.enter();
357 let pattern = for_loop.binding();
358 for ident in pattern.bindings() {
359 self.bind(ident);
360 }
361
362 self.visit(for_loop.body().to_untyped());
363 self.internal.exit();
364 }
365
366 _ => {
367 for child in node.children() {
369 self.visit(child);
370 }
371 }
372 }
373 }
374
375 fn bind(&mut self, ident: Ident) {
376 self.internal.top.bind(
377 ident.get().clone(),
378 Binding::new(Value::unit(), ident.span()),
379 );
380 }
381
382 fn capture(&mut self, ident: Ident<'a>) {
383 if self.internal.get(&ident).is_ok() {
384 return;
386 }
387
388 if self.external.get(&ident).is_ok() {
390 self.captures
391 .entry(ident.get().clone())
392 .or_insert(ident.span());
393 }
394 }
395
396 fn finish(self) -> HashMap<EcoString, Span> {
397 self.captures
398 }
399}
400
#[cfg(test)]
mod tests {
    use crate::expression::closure::CapturesVisitor;
    use crate::test::*;
    use compose_library::{Scope, Scopes};
    use compose_syntax::{FileId, parse};

    /// End-to-end capture semantics, in order:
    /// - a by-value capture;
    /// - a `ref mut` capture of a box that is mutated through the reference;
    /// - two by-value captures;
    /// - a `ref` capture that observes a later write to the outer box.
    #[test]
    fn capturing() {
        let scripts = [
            r#"
            let a = 10;
            let f = { |a| => a + 1 };
            assert::eq(f(), 11);
            "#,
            r#"
            let mut a = box::new(5);
            let f = { |ref mut a| =>
                *a += 1;
                *a;
            };
            assert::eq(f(), 6);
            "#,
            r#"
            let a = 2;
            let b = 3;
            let f = { |a, b| => a * b };
            assert::eq(f(), 6);
            "#,
            r#"
            let mut name = box::new("Alice");
            let age = 30;
            let show = { |ref name, age| =>
                "Name: " + *name + ", Age: " + age.to_string();
            };
            assert::eq(show(), "Name: Alice, Age: 30");
            *name = "Bob";
            assert::eq(show(), "Name: Bob, Age: 30");
            "#,
        ];

        for script in scripts {
            assert_eval(script);
        }
    }

    /// Parses `text`, runs one [`CapturesVisitor`] over every top-level node, and asserts
    /// that the sorted captured names equal `expected_names`.
    #[track_caller]
    fn test(scopes: &Scopes, existing_scope: &Scope, text: &str, expected_names: &[&str]) {
        let nodes = parse(text, FileId::new("test.comp"));
        let mut visitor = CapturesVisitor::new(scopes, None, existing_scope);

        for node in &nodes {
            assert!(node.errors().is_empty(), "node has errors: {:#?}", node.errors());
            visitor.visit(node);
        }

        let captures = visitor.finish();
        let mut names: Vec<_> = captures.keys().collect();
        // Map keys are unique, so an unstable sort yields the same order as a stable one.
        names.sort_unstable();

        assert_eq!(names, expected_names);
    }

    #[test]
    fn test_captures_visitor() {
        // `f`, `x`, `y`, `z` exist in the enclosing scopes.
        let mut scopes = Scopes::new(None);
        for name in ["f", "x", "y", "z"] {
            scopes.top.define(name, 0i64);
        }

        // `a`, `b`, `c` play the role of already-resolved explicit captures.
        let mut existing = Scope::new();
        for name in ["a", "b", "c"] {
            existing.define(name, 0i64);
        }

        let (s, e) = (&scopes, &existing);

        // Closure bodies are never visited.
        test(s, e, "{ x => x * 2; }", &[]);

        // `let`: the initializer is visited before the pattern is bound.
        test(s, e, "let t = x;", &["x"]);
        test(s, e, "let x = x;", &["x"]);
        test(s, e, "let x;", &[]);
        test(s, e, "let x = 2; x + y;", &["y"]);
        test(s, e, "x + y", &["x", "y"]);

        // Assignments read both sides.
        test(s, e, "x += y;", &["x", "y"]);
        test(s, e, "x = y;", &["x", "y"]);

        // Nested lambdas: only explicit captures and named defaults count.
        test(s, e, "let f = { => x + y; }", &[]);
        test(s, e, "let f = { |x| => x + y; }", &["x"]);
        test(s, e, "let f = { |x| => f(); }", &["x"]);
        test(s, e, "let f = { x, y, z => f(); }", &[]);
        test(
            s,
            e,
            "let f = { x: x, y: y, z: z => f(); }",
            &["x", "y", "z"],
        );

        // `for` loops: the iterable is outside the loop scope, the pattern inside it.
        test(s, e, "for (x in y) { x + z; };", &["y", "z"]);
        test(s, e, "for (x in y) { x; }; x", &["x", "y"]);

        // Blocks open and close a scope.
        test(s, e, "{ x; };", &["x"]);
        test(s, e, "{ let x; x; };", &[]);
        test(s, e, "{ let x; x; }; x;", &["x"]);

        // Field names are not variables.
        test(s, e, "x.y.f(z);", &["x", "z"]);

        // Parentheses are transparent.
        test(s, e, "(x + z);", &["x", "z"]);
        test(s, e, "(({ x => x + y }) + y);", &["y"]);
    }
}