1use crate::model::{Mesh, Triangle};
2use glam::Vec3;
3
4#[derive(Debug, Clone, Copy)]
6pub struct AABB {
7 pub min: Vec3,
9 pub max: Vec3,
11}
12
13impl AABB {
14 pub fn from_triangle(mesh: &Mesh, tri: &Triangle) -> Self {
16 let v1 = mesh.vertices[tri.v1 as usize];
17 let v2 = mesh.vertices[tri.v2 as usize];
18 let v3 = mesh.vertices[tri.v3 as usize];
19
20 let min = Vec3::new(
21 v1.x.min(v2.x).min(v3.x),
22 v1.y.min(v2.y).min(v3.y),
23 v1.z.min(v2.z).min(v3.z),
24 );
25 let max = Vec3::new(
26 v1.x.max(v2.x).max(v3.x),
27 v1.y.max(v2.y).max(v3.y),
28 v1.z.max(v2.z).max(v3.z),
29 );
30 Self { min, max }
31 }
32
33 pub fn intersects(&self, other: &Self) -> bool {
35 self.min.x <= other.max.x
36 && self.max.x >= other.min.x
37 && self.min.y <= other.max.y
38 && self.max.y >= other.min.y
39 && self.min.z <= other.max.z
40 && self.max.z >= other.min.z
41 }
42}
43
44pub struct BvhNode {
46 pub aabb: AABB,
48 pub content: BvhContent,
50}
51
52pub enum BvhContent {
54 Leaf(Vec<usize>), Branch(Box<BvhNode>, Box<BvhNode>),
58}
59
60impl BvhNode {
61 pub fn build(mesh: &Mesh, tri_indices: Vec<usize>) -> Self {
63 let aabbs: Vec<AABB> = tri_indices
64 .iter()
65 .map(|&i| AABB::from_triangle(mesh, &mesh.triangles[i]))
66 .collect();
67
68 let mut min = Vec3::splat(f32::INFINITY);
69 let mut max = Vec3::splat(f32::NEG_INFINITY);
70 for aabb in &aabbs {
71 min = min.min(aabb.min);
72 max = max.max(aabb.max);
73 }
74
75 let node_aabb = AABB { min, max };
76
77 if tri_indices.len() <= 8 {
78 return BvhNode {
79 aabb: node_aabb,
80 content: BvhContent::Leaf(tri_indices),
81 };
82 }
83
84 let size = max - min;
86 let axis = if size.x > size.y && size.x > size.z {
87 0
88 } else if size.y > size.z {
89 1
90 } else {
91 2
92 };
93
94 let mid = (min[axis] + max[axis]) / 2.0;
95
96 let mut left_indices = Vec::new();
97 let mut right_indices = Vec::new();
98
99 for (i, &tri_idx) in tri_indices.iter().enumerate() {
100 let center = (aabbs[i].min[axis] + aabbs[i].max[axis]) / 2.0;
101 if center < mid {
102 left_indices.push(tri_idx);
103 } else {
104 right_indices.push(tri_idx);
105 }
106 }
107
108 if left_indices.is_empty() || right_indices.is_empty() {
110 return BvhNode {
111 aabb: node_aabb,
112 content: BvhContent::Leaf(tri_indices),
113 };
114 }
115
116 BvhNode {
117 aabb: node_aabb,
118 content: BvhContent::Branch(
119 Box::new(BvhNode::build(mesh, left_indices)),
120 Box::new(BvhNode::build(mesh, right_indices)),
121 ),
122 }
123 }
124
125 pub fn find_intersections(
127 &self,
128 mesh: &Mesh,
129 tri_idx: usize,
130 tri_aabb: &AABB,
131 results: &mut Vec<usize>,
132 ) {
133 if !self.aabb.intersects(tri_aabb) {
134 return;
135 }
136
137 match &self.content {
138 BvhContent::Leaf(indices) => {
139 for &idx in indices {
140 if idx > tri_idx {
141 if tri_aabb.intersects(&AABB::from_triangle(mesh, &mesh.triangles[idx])) {
143 if intersect_triangles(mesh, tri_idx, idx) {
145 results.push(idx);
146 }
147 }
148 }
149 }
150 }
151 BvhContent::Branch(left, right) => {
152 left.find_intersections(mesh, tri_idx, tri_aabb, results);
153 right.find_intersections(mesh, tri_idx, tri_aabb, results);
154 }
155 }
156 }
157}
158
159fn intersect_triangles(mesh: &Mesh, i1: usize, i2: usize) -> bool {
161 let t1 = &mesh.triangles[i1];
162 let t2 = &mesh.triangles[i2];
163
164 let shared = count_shared_vertices(t1, t2);
170 if shared >= 2 {
171 return false;
172 }
173
174 let p1 = to_vec3(mesh.vertices[t1.v1 as usize]);
175 let p2 = to_vec3(mesh.vertices[t1.v2 as usize]);
176 let p3 = to_vec3(mesh.vertices[t1.v3 as usize]);
177
178 let q1 = to_vec3(mesh.vertices[t2.v1 as usize]);
179 let q2 = to_vec3(mesh.vertices[t2.v2 as usize]);
180 let q3 = to_vec3(mesh.vertices[t2.v3 as usize]);
181
182 tri_tri_intersect(p1, p2, p3, q1, q2, q3)
183}
184
185fn to_vec3(v: crate::model::Vertex) -> Vec3 {
186 Vec3::new(v.x, v.y, v.z)
187}
188
189fn count_shared_vertices(t1: &Triangle, t2: &Triangle) -> usize {
190 let mut count = 0;
191 let v1 = [t1.v1, t1.v2, t1.v3];
192 let v2 = [t2.v1, t2.v2, t2.v3];
193 for &va in &v1 {
194 for &vb in &v2 {
195 if va == vb {
196 count += 1;
197 }
198 }
199 }
200 count
201}
202
203fn tri_tri_intersect(p1: Vec3, p2: Vec3, p3: Vec3, q1: Vec3, q2: Vec3, q3: Vec3) -> bool {
208 let n2 = (q2 - q1).cross(q3 - q1);
210 if n2.length_squared() < 1e-12 {
211 return false;
212 } let d2 = -n2.dot(q1);
214
215 let du0 = n2.dot(p1) + d2;
217 let du1 = n2.dot(p2) + d2;
218 let du2 = n2.dot(p3) + d2;
219
220 if (du0.abs() > 1e-6 && du1.abs() > 1e-6 && du2.abs() > 1e-6)
221 && ((du0 > 0.0 && du1 > 0.0 && du2 > 0.0) || (du0 < 0.0 && du1 < 0.0 && du2 < 0.0))
222 {
223 return false; }
225
226 let n1 = (p2 - p1).cross(p3 - p1);
228 if n1.length_squared() < 1e-12 {
229 return false;
230 } let d1 = -n1.dot(p1);
232
233 let dv0 = n1.dot(q1) + d1;
235 let dv1 = n1.dot(q2) + d1;
236 let dv2 = n1.dot(q3) + d1;
237
238 if (dv0.abs() > 1e-6 && dv1.abs() > 1e-6 && dv2.abs() > 1e-6)
239 && ((dv0 > 0.0 && dv1 > 0.0 && dv2 > 0.0) || (dv0 < 0.0 && dv1 < 0.0 && dv2 < 0.0))
240 {
241 return false; }
243
244 let ld = n1.cross(n2);
246 let index = if ld.x.abs() > ld.y.abs() && ld.x.abs() > ld.z.abs() {
247 0
248 } else if ld.y.abs() > ld.z.abs() {
249 1
250 } else {
251 2
252 };
253
254 let get_interval =
256 |v1: Vec3, v2: Vec3, v3: Vec3, d1: f32, d2: f32, d3: f32| -> Option<(f32, f32)> {
257 if (d1 > 0.0 && d2 > 0.0 && d3 > 0.0) || (d1 < 0.0 && d2 < 0.0 && d3 < 0.0) {
259 return None;
260 }
261
262 let mut pts = Vec::new();
263 let tris = [(v1, v2, d1, d2), (v2, v3, d2, d3), (v3, v1, d3, d1)];
264 for (a, b, da, db) in tris {
265 if (da >= 0.0) != (db >= 0.0) {
266 let t = da / (da - db);
267 let p = a + (b - a) * t;
268 pts.push(p[index]);
269 } else if da.abs() < 1e-7 {
270 pts.push(a[index]);
271 }
272 }
273 if pts.len() < 2 {
274 return None;
275 }
276 let mut min = pts[0];
277 let mut max = pts[0];
278 for &p in &pts {
279 min = min.min(p);
280 max = max.max(p);
281 }
282 Some((min, max))
283 };
284
285 let i1 = get_interval(p1, p2, p3, du0, du1, du2);
286 let i2 = get_interval(q1, q2, q3, dv0, dv1, dv2);
287
288 match (i1, i2) {
289 (Some((t1_min, t1_max)), Some((t2_min, t2_max))) => {
290 t1_min + 1e-6 < t2_max && t2_min + 1e-6 < t1_max
293 }
294 _ => false,
295 }
296}