use std::collections::HashMap;
use std::ops::{Range, RangeFrom, RangeTo};
use crate::{
chain::{Chain, Link},
collect::Chains,
node::State,
resource::{AccessFlags, Buffer, Image, Resource},
schedule::{Queue, QueueId, Schedule, SubmissionId},
Id,
};
/// Placeholder for a semaphore connecting two submissions.
///
/// It is used as a hashable key while synchronization is being computed and is
/// later swapped for the real signal/wait halves supplied by the caller of `sync`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct Semaphore {
    /// Identifier handed to `Semaphore::new` (the id of the chain being synchronized).
    id: Id,
    /// `start` is the submission that signals, `end` is the submission that waits.
    points: Range<SubmissionId>,
}

impl Semaphore {
    /// Creates a placeholder for `id` that links `points.start` to `points.end`.
    fn new(id: Id, points: Range<SubmissionId>) -> Self {
        Self { id, points }
    }
}
/// Marks a semaphore as one that a submission has to signal.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct Signal<S>(S);

impl<S> Signal<S> {
    /// Wraps `semaphore` into a signal entry.
    fn new(semaphore: S) -> Self {
        Self(semaphore)
    }

    /// Borrows the semaphore that has to be signalled.
    pub fn semaphore(&self) -> &S {
        let Signal(inner) = self;
        inner
    }
}
/// Marks a semaphore as one that a submission has to wait on, together with
/// the pipeline stages stored alongside it.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct Wait<S>(S, gfx_hal::pso::PipelineStage);

impl<S> Wait<S> {
    /// Wraps `semaphore` and `stages` into a wait entry.
    fn new(semaphore: S, stages: gfx_hal::pso::PipelineStage) -> Self {
        Self(semaphore, stages)
    }

    /// Borrows the semaphore that has to be waited on.
    pub fn semaphore(&self) -> &S {
        let Wait(inner, _) = self;
        inner
    }

    /// Returns the pipeline stages recorded for this wait.
    pub fn stage(&self) -> gfx_hal::pso::PipelineStage {
        let Wait(_, stages) = self;
        *stages
    }
}
/// State transition recorded for a single resource of kind `R`.
#[derive(Clone, Debug)]
pub struct Barrier<R: Resource> {
    /// Queue families of an ownership transfer (`start` → `end`);
    /// `None` when the barrier was built with `Barrier::new`.
    pub families: Option<Range<gfx_hal::queue::QueueFamilyId>>,
    /// `(access, layout, stages)` before (`start`) and after (`end`) the barrier.
    pub states: Range<(R::Access, R::Layout, gfx_hal::pso::PipelineStage)>,
}

impl<R> Barrier<R>
where
    R: Resource,
{
    /// Builds a barrier without family transfer from a pair of link states.
    fn new(states: Range<State<R>>) -> Self {
        let Range {
            start: before,
            end: after,
        } = states;
        Self {
            families: None,
            states: (before.access, before.layout, before.stages)
                ..(after.access, after.layout, after.stages),
        }
    }

    /// Builds a barrier that carries a family transfer.
    ///
    /// The stage of the `start` state is always `TOP_OF_PIPE` and the stage of
    /// the `end` state is always `BOTTOM_OF_PIPE`.
    fn transfer(
        families: Range<gfx_hal::queue::QueueFamilyId>,
        states: Range<(R::Access, R::Layout)>,
    ) -> Self {
        let Range {
            start: (src_access, src_layout),
            end: (dst_access, dst_layout),
        } = states;
        Self {
            families: Some(families),
            states: (
                src_access,
                src_layout,
                gfx_hal::pso::PipelineStage::TOP_OF_PIPE,
            )
                ..(
                    dst_access,
                    dst_layout,
                    gfx_hal::pso::PipelineStage::BOTTOM_OF_PIPE,
                ),
        }
    }

    /// Acquire half of a family transfer: the source access is empty, only
    /// the source layout (`left.start`) is carried over.
    fn acquire(
        families: Range<gfx_hal::queue::QueueFamilyId>,
        left: RangeFrom<R::Layout>,
        right: RangeTo<(R::Access, R::Layout)>,
    ) -> Self {
        let source = (R::Access::empty(), left.start);
        Self::transfer(families, source..right.end)
    }

    /// Release half of a family transfer: the destination access is empty,
    /// only the destination layout (`right.end`) is carried over.
    fn release(
        families: Range<gfx_hal::queue::QueueFamilyId>,
        left: RangeFrom<(R::Access, R::Layout)>,
        right: RangeTo<R::Layout>,
    ) -> Self {
        let destination = (R::Access::empty(), right.end);
        Self::transfer(families, left.start..destination)
    }
}
/// Barriers for resources of kind `R`, keyed by the resource's `Id`.
pub type Barriers<R> = HashMap<Id, Barrier<R>>;
/// Barriers for `Buffer` resources.
pub type BufferBarriers = Barriers<Buffer>;
/// Barriers for `Image` resources.
pub type ImageBarriers = Barriers<Image>;
#[derive(Clone, Debug)]
pub struct Guard {
pub buffers: BufferBarriers,
pub images: ImageBarriers,
}
impl Guard {
fn new() -> Self {
Guard {
buffers: HashMap::default(),
images: HashMap::default(),
}
}
fn pick<R: Resource>(&mut self) -> &mut Barriers<R> {
use std::any::Any;
let Guard {
ref mut buffers,
ref mut images,
} = *self;
Any::downcast_mut(buffers)
.or_else(move || Any::downcast_mut(images))
.expect("`R` should be `Buffer` or `Image`")
}
}
/// Synchronization attached to a single submission.
///
/// `S` is the type of semaphores to signal, `W` the type of semaphores to wait on.
#[derive(Clone, Debug)]
pub struct SyncData<S, W> {
    /// Semaphores to wait on, each with its pipeline stages.
    pub wait: Vec<Wait<W>>,
    /// Barriers recorded on the acquire side.
    pub acquire: Guard,
    /// Barriers recorded on the release side.
    pub release: Guard,
    /// Semaphores to signal.
    pub signal: Vec<Signal<S>>,
}

impl<S, W> SyncData<S, W> {
    /// Creates sync data with no waits, no signals and empty guards.
    fn new() -> Self {
        Self {
            wait: Vec::new(),
            acquire: Guard::new(),
            release: Guard::new(),
            signal: Vec::new(),
        }
    }

    /// Maps every signal semaphore through `f`, in order; everything else is moved over as is.
    fn convert_signal<F, T>(self, mut f: F) -> SyncData<T, W>
    where
        F: FnMut(S) -> T,
    {
        let SyncData {
            wait,
            acquire,
            release,
            signal,
        } = self;
        let signal = signal
            .into_iter()
            .map(|entry| Signal(f(entry.0)))
            .collect();
        SyncData {
            wait,
            acquire,
            release,
            signal,
        }
    }

    /// Maps every wait semaphore through `f`, in order, keeping its stages;
    /// everything else is moved over as is.
    fn convert_wait<F, T>(self, mut f: F) -> SyncData<S, T>
    where
        F: FnMut(W) -> T,
    {
        let SyncData {
            wait,
            acquire,
            release,
            signal,
        } = self;
        let wait = wait
            .into_iter()
            .map(|entry| Wait(f(entry.0), entry.1))
            .collect();
        SyncData {
            wait,
            acquire,
            release,
            signal,
        }
    }
}
/// Working storage: sync data per submission, expressed with placeholder semaphores.
struct SyncTemp(HashMap<SubmissionId, SyncData<Semaphore, Semaphore>>);

impl SyncTemp {
    /// Returns the sync data of submission `sid`, inserting an empty entry on first access.
    fn get_sync(&mut self, sid: SubmissionId) -> &mut SyncData<Semaphore, Semaphore> {
        let SyncTemp(per_submission) = self;
        per_submission.entry(sid).or_insert_with(SyncData::new)
    }
}
/// Computes synchronization for every submission of `chains.schedule`.
///
/// Steps:
/// 1. `sync_chain` is run for every buffer chain and then every image chain,
///    collecting barriers and placeholder semaphores per submission.
/// 2. If the schedule has more than one queue, `optimize` prunes waits.
/// 3. A new schedule is built queue by queue. Each placeholder semaphore is
///    replaced by one half of a pair obtained from `new_semaphore`, which
///    returns `(signal_half, wait_half)`. The pair is created when the first
///    side of a placeholder is encountered; the other half is parked until its
///    own submission is reached.
///
/// Submissions without recorded synchronization get empty `SyncData`.
///
/// # Panics
///
/// Panics if a parked half is found where none was expected, or if a half
/// is requested twice for the same placeholder.
pub fn sync<F, S, W>(chains: &Chains, mut new_semaphore: F) -> Schedule<SyncData<S, W>>
where
    F: FnMut() -> (S, W),
{
    let schedule = &chains.schedule;

    let mut temp = SyncTemp(HashMap::default());
    for (&id, chain) in &chains.buffers {
        sync_chain(id, chain, schedule, &mut temp);
    }
    for (&id, chain) in &chains.images {
        sync_chain(id, chain, schedule, &mut temp);
    }
    if schedule.queue_count() > 1 {
        optimize(schedule, &mut temp);
    }

    let mut result = Schedule::new();
    // Halves created by the opposite side and not yet consumed.
    let mut signals: HashMap<Semaphore, Option<S>> = HashMap::default();
    let mut waits: HashMap<Semaphore, Option<W>> = HashMap::default();

    for family in schedule.iter() {
        for queue in family.iter() {
            let mut new_queue = Queue::new(queue.id());
            for submission in queue.iter() {
                let data = match temp.0.remove(&submission.id()) {
                    None => SyncData::new(),
                    Some(placeholder) => {
                        let with_signals = placeholder.convert_signal(|semaphore| {
                            if let Some(parked) = signals.get_mut(&semaphore) {
                                return parked.take().unwrap();
                            }
                            // First sighting of this placeholder: create the pair
                            // and park the wait half for the waiting submission.
                            let (signal, wait) = new_semaphore();
                            let previous = waits.insert(semaphore, Some(wait));
                            assert!(previous.is_none());
                            signal
                        });
                        with_signals.convert_wait(|semaphore| {
                            if let Some(parked) = waits.get_mut(&semaphore) {
                                return parked.take().unwrap();
                            }
                            // First sighting of this placeholder: create the pair
                            // and park the signal half for the signalling submission.
                            let (signal, wait) = new_semaphore();
                            let previous = signals.insert(semaphore, Some(signal));
                            assert!(previous.is_none());
                            wait
                        })
                    }
                };
                new_queue.add_submission_checked(submission.set_sync(data));
            }
            result.set_queue(new_queue);
        }
    }

    debug_assert!(temp.0.is_empty());
    debug_assert!(signals.values().all(|x| x.is_none()));
    debug_assert!(waits.values().all(|x| x.is_none()));
    result
}
/// Picks, among the last submissions of every queue in `link`, the one with
/// the greatest `(submit_order, queue index)` key.
///
/// # Panics
///
/// Panics if `link` has no queues.
fn latest<R, S>(link: &Link<R>, schedule: &Schedule<S>) -> SubmissionId
where
    R: Resource,
{
    link.queues()
        .map(|(qid, queue)| SubmissionId::new(qid, queue.last))
        .max_by_key(|&sid| (schedule[sid].submit_order(), sid.queue().index()))
        .unwrap()
}
/// Picks, among the first submissions of every queue in `link`, the one with
/// the smallest `(submit_order, queue index)` key.
///
/// # Panics
///
/// Panics if `link` has no queues.
fn earliest<R, S>(link: &Link<R>, schedule: &Schedule<S>) -> SubmissionId
where
    R: Resource,
{
    link.queues()
        .map(|(qid, queue)| SubmissionId::new(qid, queue.first))
        .min_by_key(|&sid| (schedule[sid].submit_order(), sid.queue().index()))
        .unwrap()
}
/// Records a placeholder semaphore signalled by `range.start` and waited on by `range.end`.
///
/// Nothing is recorded when both submissions are on the same queue. The wait
/// uses the stages that `link` reports for the waiting submission's queue.
fn generate_semaphore_pair<R: Resource>(
    sync: &mut SyncTemp,
    id: Id,
    link: &Link<R>,
    range: Range<SubmissionId>,
) {
    let Range {
        start: signal_sid,
        end: wait_sid,
    } = range;

    if signal_sid.queue() == wait_sid.queue() {
        return;
    }

    let semaphore = Semaphore::new(id, signal_sid..wait_sid);
    sync.get_sync(signal_sid)
        .signal
        .push(Signal::new(semaphore.clone()));
    sync.get_sync(wait_sid)
        .wait
        .push(Wait::new(semaphore, link.queue(wait_sid.queue()).stages));
}
/// Records into `sync` the barriers and semaphores required by one resource chain.
///
/// Every pair of neighbouring links `(prev_link, link)` is processed in order,
/// followed by a wrap-around pair `(last link, first link)` when the chain has
/// at least one link (for a single-link chain this pairs the link with itself).
///
/// # Panics
///
/// Hits `unimplemented!` for the access/family combinations marked below.
/// Note that in two of those places sync data has already been recorded
/// before the panic is raised.
fn sync_chain<R, S>(id: Id, chain: &Chain<R>, schedule: &Schedule<S>, sync: &mut SyncTemp)
where
R: Resource,
{
// Id used for the placeholder semaphores generated for this chain.
let uid = id.into();
let pairs = chain
.links()
.windows(2)
.map(|pair| (&pair[0], &pair[1]))
// Append the wrap-around pair (last, first); yields nothing for an empty chain.
.chain(
chain
.links()
.first()
.and_then(|first| chain.links().last().map(move |last| (last, first))),
);
for (prev_link, link) in pairs {
log::trace!("Sync {:#?}:{:#?}", prev_link.access(), link.access());
if prev_link.family() == link.family() {
// Same queue family: a plain barrier is used, no family transfer.
if prev_link.access().exclusive() && !link.access().exclusive() {
// Exclusive -> non-exclusive: put the barrier on the release side of the
// previous link's latest submission ...
let signal_sid = latest(prev_link, schedule);
sync.get_sync(signal_sid)
.release
.pick::<R>()
.insert(id, Barrier::new(prev_link.state()..link.state()));
// ... and let the first submission of every queue in the current link wait for it.
for (queue_id, queue) in link.queues() {
let head = SubmissionId::new(queue_id, queue.first);
generate_semaphore_pair(sync, uid, link, signal_sid..head);
}
} else {
// Otherwise the barrier goes on the acquire side of the current link's
// earliest submission, which waits for the last submission of every queue
// in the previous link.
let wait_sid = earliest(link, schedule);
for (queue_id, queue) in prev_link.queues() {
let tail = SubmissionId::new(queue_id, queue.last);
generate_semaphore_pair(sync, uid, link, tail..wait_sid);
}
sync.get_sync(wait_sid)
.acquire
.pick()
.insert(id, Barrier::new(prev_link.state()..link.state()));
// Checked only after the semaphores and barrier above were recorded.
if !link.access().exclusive() {
unimplemented!("This case is unimplemented");
}
}
} else {
// Different queue families: release on the previous link's latest submission,
// acquire on the current link's earliest submission.
let signal_sid = latest(prev_link, schedule);
let wait_sid = earliest(link, schedule);
if !prev_link.access().exclusive() {
unimplemented!("This case is unimplemented");
}
generate_semaphore_pair(sync, uid, link, signal_sid..wait_sid);
// Release half of the family transfer: previous access/layout -> new layout.
sync.get_sync(signal_sid).release.pick::<R>().insert(
id,
Barrier::release(
signal_sid.family()..wait_sid.family(),
(prev_link.access(), prev_link.layout())..,
..link.layout(),
),
);
// Acquire half of the family transfer: previous layout -> new access/layout.
sync.get_sync(wait_sid).acquire.pick::<R>().insert(
id,
Barrier::acquire(
signal_sid.family()..wait_sid.family(),
prev_link.layout()..,
..(link.access(), link.layout()),
),
);
// Checked only after the semaphore and both barriers above were recorded.
if !link.access().exclusive() {
unimplemented!("This case is unimplemented");
}
}
}
}
/// Removes waits of submission `sid` that are already covered by earlier waits.
///
/// `found` maps a signalling queue to the highest submission index on that
/// queue for which a wait has been kept so far; the caller passes the same map
/// for consecutive submissions. A wait is dropped when its signalling
/// submission's index is not greater than the recorded one; otherwise it is
/// kept and the record is raised to its index. For every dropped wait the
/// matching signal is removed from the signalling submission as well.
///
/// Does nothing when `sid` has no sync data.
///
/// # Panics
///
/// Panics if a dropped wait has no matching signal recorded on its
/// signalling submission.
fn optimize_submission(
    sid: SubmissionId,
    found: &mut HashMap<QueueId, usize>,
    sync: &mut SyncTemp,
) {
    let sync_data = match sync.0.get_mut(&sid) {
        Some(sync_data) => sync_data,
        None => return,
    };

    // Fix the order in which waits are examined by the filter below.
    sync_data
        .wait
        .sort_unstable_by_key(|wait| (wait.stage(), wait.semaphore().points.end.index()));

    let mut redundant = Vec::new();
    sync_data.wait.retain(|wait| {
        let start = wait.semaphore().points.start;
        match found.get_mut(&start.queue()) {
            Some(synched_to) => {
                if *synched_to >= start.index() {
                    redundant.push(wait.semaphore().clone());
                    false
                } else {
                    *synched_to = start.index();
                    true
                }
            }
            None => {
                found.insert(start.queue(), start.index());
                true
            }
        }
    });

    // A dropped wait leaves its signal without a consumer, so remove that too.
    for semaphore in redundant {
        let signals = &mut sync
            .0
            .get_mut(&semaphore.points.start)
            .expect("every wait must have sync data recorded for its signalling submission")
            .signal;
        let index = signals
            .iter()
            .position(|signal| *signal.semaphore() == semaphore)
            .expect("every wait must have a matching signal on its signalling submission");
        signals.swap_remove(index);
    }
}
fn optimize<S>(schedule: &Schedule<S>, sync: &mut SyncTemp) {
for queue in schedule.iter().flat_map(|family| family.iter()) {
let mut found = HashMap::default();
for submission in queue.iter() {
optimize_submission(submission.id(), &mut found, sync);
}
}
}