1use crate::vm::{FlowEvent, TrackedContainer};
2use crate::{Eval, Evaluated, Machine};
3use compose_library::diag::{IntoSourceDiagnostic, SourceResult, Spanned, bail, error};
4use compose_library::{Args, Binding, BindingKind, Closure, Func, Library, Scope, Scopes, Value, VariableAccessError, Visibility};
5use compose_syntax::ast::{AstNode, Expr, Ident, Param, ParamKind};
6use compose_syntax::{Label, Span, SyntaxNode, ast};
7use ecow::{EcoString, EcoVec};
8use std::collections::HashMap;
9
10impl Eval for ast::Lambda<'_> {
11 fn eval(self, vm: &mut Machine) -> SourceResult<Evaluated> {
12 let guard = vm.temp_root_guard();
13
14 let mut defaults = Vec::new();
15 for param in self.params().children() {
16 if let ast::ParamKind::Named(named) = param.kind() {
17 defaults.push(named.expr().eval(guard.vm)?.value);
18 }
19 }
20
21 let captured = {
22 let mut errors = EcoVec::new();
23 let mut scope = Scope::new();
24 for capture in self.captures().children() {
25 let span = capture.binding().span();
26 let name = capture.binding().get();
27 let binding = match guard.vm.get(&capture.binding()).cloned() {
28 Ok(v) => v,
29 Err(e) => {
30 let VariableAccessError::Unbound(unbound) = e else {
31 errors.push(e.into_source_diagnostic(span));
32 continue;
33 };
34
35 let mut err = error!(
36 span, "unknown variable `{name}` in closure capture list";
37 label_message: "variable `{name}` is not defined in the outer scope and cannot be captured";
38 );
39 unbound.apply_hint(&mut err);
40 errors.push(err);
41
42 continue;
43 }
44 };
45
46 let value = match validate_capture(capture, &binding, guard.vm) {
47 Ok(v) => v,
48 Err(e) => {
49 errors.extend(e.into_iter());
50 continue;
51 }
52 };
53
54 scope.bind(
55 name.clone(),
56 Binding::new(value.clone(), span).with_kind(match capture.is_mut() {
57 true => BindingKind::Mutable,
58 false => BindingKind::Immutable { first_assign: None },
59 }),
60 );
61 }
62
63 if !errors.is_empty() {
64 return Err(errors);
65 }
66
67 scope
68 };
69
70 let unresolved_captures = {
71 let mut visitor = CapturesVisitor::new(
72 &guard.vm.frames.top.scopes,
73 Some(guard.vm.engine.world.library()),
74 &captured,
75 );
76 visitor.visit_lambda(self);
77 visitor.finish()
78 };
79
80 let closure = Closure {
81 name: None,
82 node: self.to_untyped().clone(),
83 defaults,
84 num_pos_params: self
85 .params()
86 .children()
87 .filter(|p| matches!(p.kind(), ast::ParamKind::Pos(_)))
88 .count(),
89 captured,
90 unresolved_captures,
91 };
92
93 if !guard.vm.context.closure_capture.should_defer() {
94 closure.resolve()?
95 }
96
97 Ok(Evaluated::mutable(
98 Value::Func(Func::from(closure)).spanned(self.span()),
99 ))
100 }
101}
102
103fn validate_capture<'a>(
104 capture: ast::Capture,
105 binding: &'a Binding,
106 vm: &mut Machine,
107) -> SourceResult<&'a Value> {
108 let span = capture.binding().span();
109 let name = capture.binding().get();
110 if capture.is_mut() && capture.is_ref() && !binding.kind().is_mut() {
111 bail!(
112 capture.span(), "cannot capture variable `{name}` as `ref mut` because it is not declared as mutable";
113 label_message: "capture is declared as a mutable reference";
114 label: Label::secondary(binding.span(), "was defined as immutable here");
115 note: "captured mutable references must match the mutability of the original declaration";
116 hint: "declare the variable as mutable: `let mut {name} = ...`";
117 hint: "or remove `mut` from the capture: `|ref {name}, ...|";
118 );
119 }
120
121 let value = binding.read_checked(span, vm.sink_mut());
122
123 if capture.is_ref() && !value.is_box() {
124 bail!(
125 span, "cannot capture non reference type by reference";
126 label_message: "this captures by reference";
127 note: "only boxed values can be captured by reference"
128 );
129 }
130
131 Ok(value)
132}
133
134fn define(
135 vm: &mut Machine,
136 ident: Ident,
137 Spanned { value, span }: Spanned<Value>,
138 param: Param,
139) -> SourceResult<()> {
140 if param.is_ref() && !value.is_box() {
141 bail!(
142 ident.span(),
143 "cannot take a reference to a value type. consider boxing it first with `box::new(value)`"
144 );
145 }
146 if !param.is_ref() && value.is_box() {
147 bail!(span, "Cannot bind a box value to a non-reference parameter";
148 label_message: "this parameter is a boxed value";
149 label: Label::primary(ident.span(), "this parameter is declared as an owned value");
150 note: "passing a reference type (like `box`) into an owned parameter is not allowed";
151 hint: "if the parameter intends to modify the original value, mark the parameter as `ref mut`";
152 hint: "if you want to pass the value by ownership, use `.clone_inner()` to create a new copy";
153 );
154 }
155
156 let kind = if param.is_mut() {
157 BindingKind::ParamMut
158 } else {
159 BindingKind::Param
160 };
161 vm.define(ident, value, kind, Visibility::Private)?;
162
163 Ok(())
164}
165
166pub fn eval_lambda(closure: &Closure, vm: &mut Machine, args: Args) -> SourceResult<Value> {
168 let guard = vm.temp_root_guard();
169 let ast_closure = closure
170 .node
171 .cast::<ast::Lambda>()
172 .expect("closure is not an ast closure");
173 let params = ast_closure.params();
174 let statements = ast_closure.statements();
175
176 guard.vm.track_tmp_root(&args);
178
179 let result = guard
180 .vm
181 .with_frame(move |vm| {
182 let mut args = args;
183 if let Some(Spanned { value, span }) = &closure.name {
184 vm.try_bind(
185 value.clone(),
186 Binding::new(Func::from(closure.clone()), *span),
187 )?;
188 }
189
190 for (k, v) in closure.captured.bindings() {
191 vm.bind(k.clone(), v.clone());
192 }
193
194 let mut defaults = closure.defaults.iter();
195 for p in params.children() {
196 match p.kind() {
197 ast::ParamKind::Pos(pattern) => match pattern {
198 ast::Pattern::Single(Expr::Ident(ident)) => {
199 define(vm, ident, args.expect(&ident)?, p)?;
200 }
201 pattern => bail!(pattern.span(), "Patterns not supported in closures yet"),
202 },
203 ast::ParamKind::Named(named) => {
204 let name = named.name();
205 let default = defaults.next().unwrap();
206 let value = args
207 .named(&name)?
208 .unwrap_or_else(|| Spanned::new(default.clone(), named.expr().span()));
209 define(vm, name, value, p)?;
210 }
211 }
212 }
213
214 args.finish()?;
216
217 let mut output = Value::unit();
218 for statement in statements {
219 output = statement.eval(vm)?.value;
220 match &vm.flow {
221 None => {}
222 Some(FlowEvent::Return(_, Some(explicit))) => {
223 let explicit = explicit.clone();
224 vm.flow = None;
225 return Ok(explicit)
226 },
227 Some(FlowEvent::Return(_, None)) => {
228 vm.flow = None;
229 return Ok(Value::unit())
230 }
231 Some(other) => bail!(other.forbidden()),
232 }
233 }
234
235 SourceResult::Ok(output)
236 })
237 .track_tmp_root(guard.vm);
238
239 guard.vm.maybe_gc();
240
241 result
242}
243
244#[derive(Debug)]
246pub struct CapturesVisitor<'a> {
247 external: &'a Scopes<'a>,
249 internal: Scopes<'a>,
251 captures: HashMap<EcoString, Span>,
253}
254
255impl<'a> CapturesVisitor<'a> {
256 pub fn new(external: &'a Scopes<'a>, library: Option<&'a Library>, existing: &Scope) -> Self {
257 let mut inst = Self {
258 external,
259 internal: Scopes::new(library),
260 captures: HashMap::new(),
261 };
262
263 for (k, v) in existing.bindings() {
264 inst.internal.top.bind(k.clone(), v.clone());
265 }
266
267 inst
268 }
269 pub(crate) fn visit_lambda(&mut self, closure: ast::Lambda<'a>) {
270 for param in closure.params().children() {
271 match param.kind() {
272 ParamKind::Pos(pat) => {
273 for ident in pat.bindings() {
274 self.bind(ident);
275 }
276 }
277 ParamKind::Named(named) => {
278 self.bind(named.name());
279 }
280 }
281 }
282
283 for capture in closure.captures().children() {
284 self.visit(capture.to_untyped());
285 }
286
287 for statement in closure.statements() {
288 self.visit(statement.to_untyped());
289 }
290 }
291
292 pub fn visit(&mut self, node: &'a SyntaxNode) {
293 if let Some(ast::Statement::Let(let_binding)) = node.cast() {
294 if let Some(init) = let_binding.initial_value() {
295 self.visit(init.to_untyped())
296 }
297
298 for ident in let_binding.pattern().bindings() {
299 self.bind(ident);
300 }
301 return;
302 }
303
304 let expr = match node.cast::<Expr>() {
305 Some(expr) => expr,
306 None => {
307 if let Some(named) = node.cast::<ast::Named>() {
308 self.visit(named.expr().to_untyped());
310 return;
311 }
312
313 Expr::default()
314 }
315 };
316
317 match expr {
318 Expr::Ident(ident) => self.capture(ident),
319 Expr::CodeBlock(_) => {
320 self.internal.enter();
321 for child in node.children() {
322 self.visit(child);
323 }
324 self.internal.exit();
325 }
326 Expr::FieldAccess(access) => {
327 self.visit(access.target().to_untyped());
328 }
329 Expr::Lambda(closure) => {
330 for param in closure.params().children() {
331 if let ast::ParamKind::Named(named) = param.kind() {
332 self.visit(named.expr().to_untyped());
333 }
334 }
335
336 for capture in closure.captures().children() {
337 self.visit(capture.to_untyped());
338 }
339
340 }
345
346 Expr::ForLoop(for_loop) => {
347 self.visit(for_loop.iterable().to_untyped());
349
350 self.internal.enter();
351 let pattern = for_loop.binding();
352 for ident in pattern.bindings() {
353 self.bind(ident);
354 }
355
356 self.visit(for_loop.body().to_untyped());
357 self.internal.exit();
358 }
359
360 _ => {
361 for child in node.children() {
363 self.visit(child);
364 }
365 }
366 }
367 }
368
369 fn bind(&mut self, ident: Ident) {
370 self.internal.top.bind(
371 ident.get().clone(),
372 Binding::new(Value::unit(), ident.span()),
373 );
374 }
375
376 fn capture(&mut self, ident: Ident<'a>) {
377 if self.internal.get(&ident).is_ok() {
378 return;
380 }
381
382 if self.external.get(&ident).is_ok() {
384 self.captures
385 .entry(ident.get().clone())
386 .or_insert(ident.span());
387 }
388 }
389
390 fn finish(self) -> HashMap<EcoString, Span> {
391 self.captures
392 }
393}
394
#[cfg(test)]
mod tests {
    use crate::expression::closure::CapturesVisitor;
    use crate::test::*;
    use compose_library::{Scope, Scopes};
    use compose_syntax::{FileId, parse};

    /// End-to-end capture semantics: by-value captures copy the current
    /// value, and `ref` captures of boxes observe later writes.
    #[test]
    fn capturing() {
        // Single by-value capture.
        assert_eval(
            r#"
            let a = 10;
            let f = { |a| => a + 1 };
            assert::eq(f(), 11);
            "#,
        );

        // `ref mut` capture mutating a box.
        assert_eval(
            r#"
            let mut a = box::new(5);
            let f = { |ref mut a| =>
                *a += 1;
                *a;
            };
            assert::eq(f(), 6);
            "#,
        );

        // Several by-value captures.
        assert_eval(
            r#"
            let a = 2;
            let b = 3;
            let f = { |a, b| => a * b };
            assert::eq(f(), 6);
            "#,
        );

        // Mixed `ref` and by-value captures; the `ref` one sees the update.
        assert_eval(
            r#"
            let mut name = box::new("Alice");
            let age = 30;
            let show = { |ref name, age| =>
                "Name: " + *name + ", Age: " + age.to_string();
            };
            assert::eq(show(), "Name: Alice, Age: 30");
            *name = "Bob";
            assert::eq(show(), "Name: Bob, Age: 30");
            "#,
        );
    }

    /// Parses `text`, visits every top-level node with a fresh
    /// [`CapturesVisitor`], and asserts that the sorted captured names equal
    /// `result`.
    #[track_caller]
    fn test(scopes: &Scopes, existing_scope: &Scope, text: &str, result: &[&str]) {
        let mut visitor = CapturesVisitor::new(scopes, None, existing_scope);
        let nodes = parse(text, FileId::new("test.comp"));
        for node in &nodes {
            visitor.visit(node);
        }

        let captures = visitor.finish();
        let mut names: Vec<_> = captures.keys().collect();
        names.sort_unstable();

        assert_eq!(names, result);
    }

    #[test]
    fn test_captures_visitor() {
        // Outer scope visible to the visitor.
        let mut scopes = Scopes::new(None);
        for outer in ["f", "x", "y", "z"] {
            scopes.top.define(outer, 0i64);
        }

        // Pre-existing captures; these must never be reported.
        let mut existing = Scope::new();
        for seeded in ["a", "b", "c"] {
            existing.define(seeded, 0i64);
        }

        let (s, e) = (&scopes, &existing);

        // Closure parameters shadow outer names.
        test(s, e, "{ x => x * 2; }", &[]);

        // `let` bindings: the initializer is visited before binding.
        test(s, e, "let t = x;", &["x"]);
        test(s, e, "let x = x;", &["x"]);
        test(s, e, "let x;", &[]);
        test(s, e, "let x = 2; x + y;", &["y"]);
        test(s, e, "x + y", &["x", "y"]);

        // Assignment targets count as uses.
        test(s, e, "x += y;", &["x", "y"]);
        test(s, e, "x = y;", &["x", "y"]);

        // Nested closures: only capture lists and named defaults are visited.
        test(s, e, "let f = { => x + y; }", &[]);
        test(s, e, "let f = { |x| => x + y; }", &["x"]);
        test(s, e, "let f = { |x| => f(); }", &["x"]);
        test(s, e, "let f = { x, y, z => f(); }", &[]);
        test(
            s,
            e,
            "let f = { x = x, y = y, z = z => f(); }",
            &["x", "y", "z"],
        );

        // `for` loops scope their binding to the loop.
        test(s, e, "for (x in y) { x + z; };", &["y", "z"]);
        test(s, e, "for (x in y) { x; }; x", &["x", "y"]);

        // Code blocks open and close their own scope.
        test(s, e, "{ x; };", &["x"]);
        test(s, e, "{ let x; x; };", &[]);
        test(s, e, "{ let x; x; }; x;", &["x"]);

        // Field names are not variable uses.
        test(s, e, "x.y.f(z);", &["x", "z"]);

        // Parenthesized expressions.
        test(s, e, "(x + z);", &["x", "z"]);
        test(s, e, "(({ x => x + y }) + y);", &["y"]);
    }
}