compose_library/foundations/iterator/
iter_combinators.rs

1use crate::IterValue;
2use crate::diag::StrResult;
3use compose_library::diag::{At, SourceResult, bail};
4use compose_library::vm::Vm;
5use compose_library::{Args, Func, Trace, UntypedRef, Value, ValueIterator};
6use std::iter;
7use std::ops::DerefMut;
8use std::sync::{Arc, Mutex};
9use compose_library::support::eval_predicate;
10
/// Iterator adapter that yields at most a fixed number of items from `inner`.
#[derive(Debug, Clone)]
pub struct TakeIter {
    /// The iterator that items are pulled from.
    pub(crate) inner: IterValue,
    /// How many more items may still be yielded. Kept behind `Arc<Mutex<_>>` so it can be
    /// updated through `&self`; the derived `Clone` copies the `Arc`, so clones share it.
    pub(crate) take: Arc<Mutex<usize>>,
}
16
17impl TakeIter {
18    pub fn new(inner: IterValue, take: usize) -> Self {
19        Self {
20            inner,
21            take: Arc::new(Mutex::new(take)),
22        }
23    }
24}
25
26impl PartialEq for TakeIter {
27    fn eq(&self, other: &Self) -> bool {
28        if self.inner != other.inner {
29            return false;
30        }
31
32        let take_a = self.take.lock().expect("mutex poisoned");
33        let take_b = other.take.lock().expect("mutex poisoned");
34
35        if *take_a != *take_b {
36            return false;
37        }
38
39        true
40    }
41}
42
43impl Trace for TakeIter {
44    fn visit_refs(&self, f: &mut dyn FnMut(UntypedRef)) {
45        self.inner.visit_refs(f);
46    }
47}
48
49impl ValueIterator for TakeIter {
50    fn next(&self, vm: &mut dyn Vm) -> SourceResult<Option<Value>> {
51        self.nth(vm, 0)
52    }
53
54    fn nth(&self, vm: &mut dyn Vm, n: usize) -> SourceResult<Option<Value>> {
55        let mut take = self.take.lock().expect("take poisoned");
56        if *take <= n {
57            *take = 0;
58            return Ok(None);
59        }
60
61        *take = *take - n - 1; // minus one because n is 0 indexed. (for 0 we do yield an item, so take should be decremented)
62        drop(take);
63
64        self.inner.nth(vm, n)
65    }
66}
67
/// Iterator adapter that discards the first `skip` items of `inner` and yields the rest.
#[derive(Debug, Clone)]
pub struct SkipIter {
    /// The iterator that items are pulled from.
    pub(crate) inner: IterValue,
    /// Number of items still to be discarded. It is reset to `0` by the first `next`/`nth`
    /// call. The derived `Clone` copies the `Arc`, so clones share this counter.
    pub(crate) skip: Arc<Mutex<usize>>,
}
73
74impl PartialEq for SkipIter {
75    fn eq(&self, other: &Self) -> bool {
76        if self.inner != other.inner {
77            return false;
78        }
79
80        {
81            let skip_a = self.skip.lock().expect("mutex poisoned");
82            let skip_b = self.skip.lock().expect("mutex poisoned");
83
84            if *skip_a != *skip_b {
85                return false;
86            }
87        }
88
89        true
90    }
91}
92
93impl Trace for SkipIter {
94    fn visit_refs(&self, f: &mut dyn FnMut(UntypedRef)) {
95        self.inner.visit_refs(f)
96    }
97}
98
99impl ValueIterator for SkipIter {
100    fn next(&self, vm: &mut dyn Vm) -> SourceResult<Option<Value>> {
101        let mut skip = self.skip.lock().expect("Skip lock poisoned");
102        let n = std::mem::replace(skip.deref_mut(), 0);
103        self.inner.nth(vm, n)
104    }
105
106    fn nth(&self, vm: &mut dyn Vm, n: usize) -> SourceResult<Option<Value>> {
107        let mut skip = self.skip.lock().expect("Skip lock poisoned");
108        let to_skip = std::mem::replace(skip.deref_mut(), 0);
109        self.inner.nth(vm, to_skip + n)
110    }
111}
112
/// Iterator adapter that yields items from `inner` for as long as `predicate` returns `true`.
///
/// NOTE(review): this struct records no "finished" state. After the predicate rejects an item,
/// a later `next` call resumes pulling from `inner` and may yield items again, whereas
/// `core::iter::TakeWhile` stops permanently. Confirm which behavior is intended.
#[derive(Debug, Clone, PartialEq)]
pub struct TakeWhileIter {
    /// The iterator that items are pulled from.
    pub(crate) inner: IterValue,
    /// Called with each item; a `false` result ends the current run.
    pub(crate) predicate: Arc<Func>,
}
118
119impl Trace for TakeWhileIter {
120    fn visit_refs(&self, f: &mut dyn FnMut(UntypedRef)) {
121        self.inner.visit_refs(f);
122        self.predicate.visit_refs(f);
123    }
124}
125
126impl ValueIterator for TakeWhileIter {
127    fn next(&self, vm: &mut dyn Vm) -> SourceResult<Option<Value>> {
128        let item = match self.inner.next(vm) {
129            Ok(Some(item)) => item,
130            Ok(None) => return Ok(None),
131            Err(err) => return Err(err),
132        };
133
134        let span = self.predicate.span();
135        let args = Args::new(span, iter::once(item.clone()));
136
137        if self
138            .predicate
139            .call(vm, args)?
140            .cast::<bool>()
141            .at(span)
142            .map_err(|mut err| {
143                err.make_mut()[0].hint("predicate must return a boolean");
144                err
145            })?
146        {
147            Ok(Some(item))
148        } else {
149            Ok(None)
150        }
151    }
152
153    // nth method cannot be optimized here, so we just fall back to the default
154}
155
/// Iterator adapter that yields only those items of `inner` that `predicate` accepts.
#[derive(Debug, Clone, PartialEq)]
pub struct FilterIter {
    /// The iterator that items are pulled from.
    pub(crate) inner: IterValue,
    /// Evaluated for every item pulled from `inner`.
    pub(crate) predicate: Arc<Func>,
}
161
162impl Trace for FilterIter {
163    fn visit_refs(&self, f: &mut dyn FnMut(UntypedRef)) {
164        self.inner.visit_refs(f);
165        self.predicate.visit_refs(f);
166    }
167}
168
169impl ValueIterator for FilterIter {
170    fn next(&self, vm: &mut dyn Vm) -> SourceResult<Option<Value>> {
171        loop {
172            let item = match self.inner.next(vm, ) {
173                Ok(Some(item)) => item,
174                Ok(None) => return Ok(None),
175                Err(err) => return Err(err),
176            };
177
178            if eval_predicate(vm, &self.predicate, item.clone(), "filter")? {
179                return Ok(Some(item));
180            }
181        }
182    }
183
184    // Filter is not nth optimizable because it needs to evaluate predicate for each item it examines
185    // both for potential side effects and for keeping track of which values it would have yielded.
186    // Fall back to the default implementation
187}
188
/// Iterator adapter that calls `map` on items of `inner` and yields the results.
#[derive(Debug, Clone, PartialEq)]
pub struct MapIter {
    /// The iterator that items are pulled from.
    pub(crate) inner: IterValue,
    /// Called with each yielded item; its return value is what this iterator yields.
    pub(crate) map: Arc<Func>,
}
194
195impl Trace for MapIter {
196    fn visit_refs(&self, f: &mut dyn FnMut(UntypedRef)) {
197        self.inner.visit_refs(f);
198        self.map.visit_refs(f);
199    }
200}
201
202impl ValueIterator for MapIter {
203    fn next(&self, vm: &mut dyn Vm) -> SourceResult<Option<Value>> {
204        self.nth(vm, 0)
205    }
206
207    fn nth(&self, vm: &mut dyn Vm, n: usize) -> SourceResult<Option<Value>> {
208        self.inner
209            .nth(vm, n)?
210            .map(|item| {
211                let span = self.map.span();
212                let args = Args::new(span, iter::once(item.clone()));
213
214                self.map.call(vm, args)
215            })
216            .transpose()
217    }
218}
219
220impl StepByIter {
221    pub fn new(inner: IterValue, step: usize) -> StrResult<Self> {
222        if step == 0 {
223            bail!("step must be greater than 0");
224        }
225
226        Ok(Self {
227            inner,
228            step,
229            first_step: Arc::new(Mutex::new(true)),
230        })
231    }
232}
233
234impl PartialEq for StepByIter {
235    fn eq(&self, other: &Self) -> bool {
236        if self.step != other.step {
237            return false;
238        }
239
240        if *self.first_step.lock().unwrap() != *other.first_step.lock().unwrap() {
241            return false;
242        }
243
244        if self.inner != other.inner {
245            return false;
246        }
247
248        true
249    }
250}
251
/// Iterator adapter that yields the first item of `inner` and then every `step`-th item after
/// it. Construct it with `StepByIter::new`, which rejects a zero step.
#[derive(Debug, Clone)]
pub struct StepByIter {
    /// The iterator that items are pulled from.
    inner: IterValue,
    /// Distance between consecutive yielded items; non-zero when built via `new`.
    step: usize,
    /// `true` until the first `next`/`nth` call. That first call must not skip the `step - 1`
    /// items that later calls skip. The derived `Clone` copies the `Arc`, so clones share it.
    first_step: Arc<Mutex<bool>>,
}
258
259impl ValueIterator for StepByIter {
260    fn next(&self, vm: &mut dyn Vm) -> SourceResult<Option<Value>> {
261        self.nth(vm, 0)
262    }
263
264    fn nth(&self, vm: &mut dyn Vm, n: usize) -> SourceResult<Option<Value>> {
265        let mut first_step = self.first_step.lock().unwrap();
266        let first_step = std::mem::replace(first_step.deref_mut(), false);
267
268        // Compute the number of elements to skip in the underlying iterator.
269        //
270        // If this is the first call and the first element hasn't been yielded yet,
271        // we must consume it (index 0) before starting to step.
272        //
273        // For step = 2:
274        // - first_step = true:
275        //   nth(1) should yield index: 2 * 1 + 1 = 3
276        // - first_step = false:
277        //   nth(1) should yield index: 2 * (1 + 1) = 4
278        //
279        // This ensures consistent stepping after the initial yield.
280        let offset = if first_step {
281            self.step * n
282        } else {
283            self.step * (n + 1) - 1
284        };
285
286        self.inner.nth(vm, offset)
287    }
288}
289
290impl Trace for StepByIter {
291    fn visit_refs(&self, f: &mut dyn FnMut(UntypedRef)) {
292        self.inner.visit_refs(f);
293    }
294}