compose_library/foundations/iterator/range_iter.rs

1use crate::diag::SourceResult;
2use crate::Range;
3use compose_library::diag::{bail, StrResult};
4use compose_library::{IntoValue, Str, Value, ValueIterator, Vm};
5use ecow::eco_format;
6use std::sync::{Arc, Mutex};
7
/// Cursor state of a [`RangeIter`].
///
/// Characters are kept as raw `u32` code points so that stepping is plain
/// integer arithmetic; validity is checked when a value is produced by `nth`.
#[derive(Debug, Clone, PartialEq)]
enum RangeIterType {
    Int(i64),
    Char(u32),
}

impl RangeIterType {
    /// Skips `n` values, returns the one after them, and moves the cursor just
    /// past the returned value (the same contract as `Iterator::nth`).
    ///
    /// Returns `None`, leaving the cursor untouched, when the target value
    /// cannot be represented: the offset or sum overflows the underlying
    /// integer, or (for `Char`) the result is not a valid Unicode scalar value
    /// (above `U+10FFFF` or inside the surrogate range).
    fn nth(&mut self, n: usize) -> Option<RangeIterType> {
        match self {
            RangeIterType::Int(i) => {
                // `n as i64` would wrap for huge `n`; treat that as "past the end".
                let value = i.checked_add(i64::try_from(n).ok()?)?;
                // NOTE(review): when `value == i64::MAX` there is no successor to
                // store, so this `+ 1` overflows (as the original did). Fixing that
                // needs an explicit "exhausted" state in this enum.
                *i = value + 1;
                Some(RangeIterType::Int(value))
            }
            RangeIterType::Char(c) => {
                let value = c.checked_add(u32::try_from(n).ok()?)?;
                // Reject surrogates and out-of-range code points.
                char::from_u32(value)?;
                // `value <= 0x10FFFF` here, so `+ 1` cannot overflow.
                *c = value + 1;
                // Yield the validated value, not the already-advanced cursor.
                Some(RangeIterType::Char(value))
            }
        }
    }
}
34
35impl IntoValue for RangeIterType {
36    fn into_value(self) -> Value {
37        match self {
38            RangeIterType::Int(c) => Value::Int(c),
39            RangeIterType::Char(c) => Value::Str(Str::from(eco_format!(
40                "{}",
41                char::from_u32(c).expect("Invalid char code. This is a bug.")
42            ))),
43        }
44    }
45}
46
47#[derive(Debug, Clone)]
48pub struct RangeIter {
49    current: Arc<Mutex<RangeIterType>>,
50    max_inclusive: bool,
51    max: Option<RangeIterType>,
52}
53
54impl PartialEq for RangeIter {
55    fn eq(&self, other: &Self) -> bool {
56        if self.max_inclusive != other.max_inclusive {
57            return false;
58        }
59        if self.max != other.max {
60            return false;
61        }
62        
63        let cur = self.current.lock().unwrap();
64        let other_cur = other.current.lock().unwrap();
65        
66        if *cur != *other_cur {
67            return false;
68        }
69        
70        true
71    }
72}
73
74impl RangeIter {
75    pub fn new(range: &Range) -> StrResult<Self> {
76        let (cur, max, inclusive) = match range {
77            Range::Int(r) => {
78                (r.start.map(RangeIterType::Int), r.end.map(RangeIterType::Int), r.include_end)
79            }
80            Range::Char(r) => {
81                (r.start.map(|char| RangeIterType::Char(char as u32)), r.end.map(|char| RangeIterType::Char(char as u32)), r.include_end)
82            }
83        };
84        
85        let Some(cur) = cur else {
86            bail!("Range iterator must have a start value.");
87        };
88        
89        Ok(Self {
90            current: Arc::new(Mutex::new(cur)),
91            max_inclusive: inclusive,
92            max,
93        })
94    }
95}
96
97impl ValueIterator for RangeIter {
98    fn next(&self, vm: &mut dyn Vm) -> SourceResult<Option<Value>> {
99        self.nth(vm, 0)
100    }
101
102    fn nth(&self, _: &mut dyn Vm<'_>, n: usize) -> SourceResult<Option<Value>> {
103        let mut current = self.current.lock().unwrap();
104        let Some(next) = current.nth(n) else {
105            return Ok(None);
106        };
107
108        let Some(max) = &self.max else {
109            return Ok(Some(next.into_value()));
110        };
111
112        let max_ok = match (&next, max) {
113            (RangeIterType::Char(c), RangeIterType::Char(m)) => {
114                if self.max_inclusive {
115                    c <= m
116                } else {
117                    c < m
118                }
119            }
120            (RangeIterType::Int(c), RangeIterType::Int(m)) => {
121                if self.max_inclusive {
122                    c <= m
123                } else {
124                    c < m
125                }
126            }
127            _ => unreachable!("Invalid range iterator state. This is a bug."),
128        };
129
130        if !max_ok {
131            return Ok(None);
132        }
133
134        Ok(Some(next.into_value()))
135    }
136}