// compose_library/foundations/iterator/mod.rs

1use crate::{ArrayValue, HeapRef, MapValue, Trace};
2use crate::{UntypedRef, Value};
3use compose_library::diag::{SourceResult, bail, error};
4use compose_library::vm::Vm;
5use compose_library::{Func, Str};
6use compose_macros::{func, scope, ty};
7use compose_syntax::Span;
8use std::collections::HashMap;
9use std::fmt::Debug;
10use std::sync::{Arc, Mutex};
11
12mod array_iter;
13mod iter_combinators;
14mod range_iter;
15mod string_iter;
16
17use crate::diag::{SourceDiagnostic, StrResult, UnSpanned};
18use crate::support::eval_func;
19pub use array_iter::*;
20use compose_library::support::eval_predicate;
21pub use iter_combinators::*;
22pub use range_iter::*;
23pub use string_iter::*;
24
25#[ty(scope, cast, name = "Iterator")]
26#[derive(Debug, Clone, PartialEq, Copy)]
27pub struct IterValue {
28    iter: HeapRef<Iter>,
29}
30
31impl Trace for IterValue {
32    fn visit_refs(&self, f: &mut dyn FnMut(UntypedRef)) {
33        f(self.iter.key())
34    }
35}
36
37impl IterValue {
38    pub(crate) fn new(iter: Iter, vm: &mut dyn Vm) -> Self {
39        Self {
40            iter: vm.heap_mut().alloc(iter),
41        }
42    }
43
44    pub fn try_from_value(
45        value: Value,
46        mutable: bool,
47        vm: &mut dyn Vm,
48    ) -> Result<IterValue, UnSpanned<SourceDiagnostic>> {
49        match value {
50            Value::Iterator(i) if mutable => Ok(i),
51
52            Value::Str(_) => Err(error!(
53                Span::detached(), "cannot iterate over a string directly";
54                hint: "try calling `.chars()` to iterate over the characters of the string"
55            )
56            .into()),
57            Value::Box(_) => Err(error!(
58                Span::detached(), "cannot iterate over a boxed value directly";
59                hint: "try dereferencing the box first with `*`"
60            )
61            .into()),
62            Value::Array(arr) => Ok(IterValue::new(
63                Iter::Array(ArrayIter::new(
64                    arr.try_get(vm.heap())
65                        .map_err(|e| SourceDiagnostic::error(Span::detached(), e))?,
66                )),
67                vm,
68            )),
69            Value::Range(range) => Ok(IterValue::new(
70                Iter::Range(
71                    RangeIter::new(range.inner())
72                        .map_err(|e| SourceDiagnostic::error(Span::detached(), e))?,
73                ),
74                vm,
75            )),
76            other @ (Value::Int(_)
77            | Value::Func(_)
78            | Value::Type(_)
79            | Value::Bool(_)
80            | Value::Unit(_)) => {
81                Err(error!(Span::detached(), "cannot iterate over type {}", other.ty()).into())
82            }
83            immut => requires_mutable_iter(immut),
84        }
85    }
86
87    pub(crate) fn shallow_clone(&self, vm: &mut dyn Vm) -> Self {
88        let self_ = self.iter.get_unwrap(vm.heap()).clone();
89
90        Self {
91            iter: vm.heap_mut().alloc(self_),
92        }
93    }
94}
95
96pub fn requires_mutable_iter<T>(value: Value) -> Result<T, UnSpanned<SourceDiagnostic>> {
97    Err(error!(
98        Span::detached(),
99        "cannot iterate over a value of type {} that is not marked as mutable",
100        value.ty()
101    )
102    .into())
103}
104
105#[derive(Clone, Debug, PartialEq)]
106pub enum Iter {
107    String(StringIterator),
108    Array(ArrayIter),
109    Take(TakeIter),
110    TakeWhile(TakeWhileIter),
111    Map(MapIter),
112    Skip(SkipIter),
113    Range(RangeIter),
114    StepBy(StepByIter),
115    Filter(FilterIter),
116}
117
118impl Trace for Iter {
119    fn visit_refs(&self, f: &mut dyn FnMut(UntypedRef)) {
120        match self {
121            Iter::String(_) => {}
122            Iter::TakeWhile(iter) => iter.visit_refs(f),
123            Iter::Take(iter) => iter.visit_refs(f),
124            Iter::Skip(iter) => iter.visit_refs(f),
125            Iter::Map(iter) => iter.visit_refs(f),
126            Iter::Array(arr) => arr.visit_refs(f),
127            Iter::Range(_) => {}
128            Iter::StepBy(iter) => iter.visit_refs(f),
129            Iter::Filter(filter) => filter.visit_refs(f),
130        }
131    }
132}
133
134impl Iter {
135    pub fn next(&self, vm: &mut dyn Vm) -> SourceResult<Option<Value>> {
136        match self {
137            Iter::String(s) => ValueIterator::next(s, vm),
138            Iter::Take(t) => t.next(vm),
139            Iter::TakeWhile(t) => t.next(vm),
140            Iter::Map(m) => m.next(vm),
141            Iter::Skip(s) => s.next(vm),
142            Iter::Array(a) => a.next(vm),
143            Iter::Range(r) => r.next(vm),
144            Iter::StepBy(s) => s.next(vm),
145            Iter::Filter(f) => f.next(vm),
146        }
147    }
148
149    pub fn nth(&self, vm: &mut dyn Vm, n: usize) -> SourceResult<Option<Value>> {
150        match self {
151            Iter::String(s) => ValueIterator::nth(s, vm, n),
152            Iter::Take(t) => t.nth(vm, n),
153            Iter::TakeWhile(t) => t.nth(vm, n),
154            Iter::Map(m) => m.nth(vm, n),
155            Iter::Skip(s) => s.nth(vm, n),
156            Iter::Array(a) => a.nth(vm, n),
157            Iter::Range(r) => r.nth(vm, n),
158            Iter::StepBy(s) => s.nth(vm, n),
159            Iter::Filter(f) => f.nth(vm, n),
160        }
161    }
162}
163
164impl ValueIterator for IterValue {
165    fn next(&self, vm: &mut dyn Vm) -> SourceResult<Option<Value>> {
166        let iter = self.iter.get_unwrap(vm.heap()).clone();
167        iter.next(vm)
168    }
169
170    fn nth(&self, vm: &mut dyn Vm, n: usize) -> SourceResult<Option<Value>> {
171        let iter = self.iter.get_unwrap(vm.heap()).clone();
172        iter.nth(vm, n)
173    }
174}
175
176#[scope]
177impl IterValue {
178    #[func(name = "next")]
179    fn next_(&mut self, vm: &mut dyn Vm) -> SourceResult<Option<Value>> {
180        let iter = self.iter.get_unwrap(vm.heap()).clone();
181
182        iter.next(vm)
183    }
184
185    #[func(name = "nth")]
186    fn nth_(&mut self, vm: &mut dyn Vm, n: usize) -> SourceResult<Option<Value>> {
187        let iter = self.iter.get_unwrap(vm.heap()).clone();
188
189        iter.nth(vm, n)
190    }
191
192    #[func]
193    fn take(self, vm: &mut dyn Vm, n: usize) -> Self {
194        IterValue::new(Iter::Take(TakeIter::new(self, n)), vm)
195    }
196
197    #[func]
198    fn take_while(self, vm: &mut dyn Vm, predicate: Func) -> Self {
199        IterValue::new(
200            Iter::TakeWhile(TakeWhileIter {
201                inner: self,
202                predicate: Arc::new(predicate),
203            }),
204            vm,
205        )
206    }
207
208    #[func]
209    fn map(self, vm: &mut dyn Vm, map: Func) -> Self {
210        IterValue::new(
211            Iter::Map(MapIter {
212                inner: self,
213                map: Arc::new(map),
214            }),
215            vm,
216        )
217    }
218
219    #[func]
220    fn skip(self, vm: &mut dyn Vm, n: usize) -> Self {
221        IterValue::new(
222            Iter::Skip(SkipIter {
223                inner: self,
224                skip: Arc::new(Mutex::new(n)),
225            }),
226            vm,
227        )
228    }
229
230    #[func]
231    fn step_by(self, vm: &mut dyn Vm, step: usize) -> StrResult<Self> {
232        Ok(IterValue::new(
233            Iter::StepBy(StepByIter::new(self, step)?),
234            vm,
235        ))
236    }
237
238    #[func]
239    fn filter(self, vm: &mut dyn Vm, predicate: Func) -> Self {
240        IterValue::new(
241            Iter::Filter(FilterIter {
242                inner: self,
243                predicate: Arc::new(predicate),
244            }),
245            vm,
246        )
247    }
248
249    #[func]
250    fn find(&mut self, vm: &mut dyn Vm, predicate: Func) -> SourceResult<Option<Value>> {
251        while let Some(v) = self.next(vm)? {
252            if eval_predicate(vm, &predicate, v.clone(), "find")? {
253                return Ok(Some(v));
254            }
255        }
256
257        Ok(None)
258    }
259
260    #[func]
261    fn all(&mut self, vm: &mut dyn Vm, predicate: Func) -> SourceResult<bool> {
262        while let Some(v) = self.next(vm)? {
263            if !eval_predicate(vm, &predicate, v.clone(), "all")? {
264                return Ok(false);
265            }
266        }
267
268        Ok(true)
269    }
270
271    #[func]
272    fn any(&mut self, vm: &mut dyn Vm, predicate: Func) -> SourceResult<bool> {
273        while let Some(v) = self.next(vm)? {
274            if eval_predicate(vm, &predicate, v.clone(), "any")? {
275                return Ok(true);
276            }
277        }
278
279        Ok(false)
280    }
281
282    #[func]
283    fn position(&mut self, vm: &mut dyn Vm, predicate: Func) -> SourceResult<Option<usize>> {
284        let mut i = 0;
285        while let Some(v) = self.next(vm)? {
286            if eval_predicate(vm, &predicate, v.clone(), "position")? {
287                return Ok(Some(i));
288            }
289            i += 1;
290        }
291        Ok(None)
292    }
293
294    #[func]
295    fn to_array(self, vm: &mut dyn Vm) -> SourceResult<ArrayValue> {
296        let mut values = Vec::new();
297        while let Some(v) = self.next(vm)? {
298            values.push(v);
299        }
300
301        Ok(ArrayValue::from(vm.heap_mut(), values))
302    }
303
304    #[func]
305    fn to_map(
306        self,
307        vm: &mut dyn Vm,
308        key_mapper: Func,
309        value_mapper: Func,
310    ) -> SourceResult<MapValue> {
311        let mut map = HashMap::new();
312        while let Some(v) = self.next(vm)? {
313            let k = eval_func(vm, &key_mapper, [v.clone()])?;
314
315            let Value::Str(Str(key)) = k else {
316                bail!(key_mapper.span, "key mapper must return a string");
317            };
318
319            let v = eval_func(vm, &value_mapper, [v.clone()])?;
320
321            map.insert(key, v);
322        }
323
324        Ok(MapValue::from(vm.heap_mut(), map))
325    }
326}
327
328pub trait ValueIterator: Debug + Send + Sync {
329    /// Compute the next element of the iterator
330    fn next(&self, vm: &mut dyn Vm) -> SourceResult<Option<Value>>;
331
332    /// Compute the nth element of the iterator.
333    fn nth(&self, vm: &mut dyn Vm, n: usize) -> SourceResult<Option<Value>> {
334        if n == 0 {
335            return self.next(vm);
336        }
337
338        let iter = self;
339        for _ in 0..n {
340            match iter.next(vm)? {
341                Some(_) => {}
342                None => return Ok(None),
343            }
344        }
345
346        iter.next(vm)
347    }
348}
349
/// Generates `impl From<$ty> for Iter` for each `Variant => Type` pair, so every
/// concrete iterator can be lifted into `Iter` with `Iter::from(x)` or `x.into()`.
///
/// `From` is implemented rather than `Into`, for two reasons. The standard library's
/// blanket `impl<T, U: From<T>> Into<U> for T` still provides `.into()` for every
/// existing caller. Clippy's `from_over_into` lint flags hand-written `Into` impls.
macro_rules! impl_into_iter {
    (
        $($iter:ident => $ty:ty),* $(,)?
    ) => {
        $(
            impl From<$ty> for Iter {
                fn from(value: $ty) -> Iter {
                    Iter::$iter(value)
                }
            }
        )*
    };
}
363
364impl_into_iter!(
365    String => StringIterator,
366    Array => ArrayIter,
367    Take => TakeIter,
368    TakeWhile => TakeWhileIter,
369    Map => MapIter,
370    Skip => SkipIter,
371    Range => RangeIter,
372    StepBy => StepByIter,
373    Filter => FilterIter,
374);