#![warn(
missing_debug_implementations,
missing_copy_implementations,
missing_docs,
trivial_casts,
trivial_numeric_casts,
unused_extern_crates,
unused_import_braces,
unused_qualifications
)]
#[cfg(feature = "shader-compiler")]
mod shaderc;
#[cfg(feature = "spirv-reflection")]
#[allow(dead_code)]
mod reflect;
#[cfg(feature = "shader-compiler")]
pub use self::shaderc::*;
#[cfg(feature = "spirv-reflection")]
pub use self::reflect::SpirvReflection;
use gfx_hal::{pso::ShaderStageFlags, Backend};
use std::collections::HashMap;
/// A source of SPIR-V shader code together with its entry point and pipeline stage.
pub trait Shader {
    /// Returns the SPIR-V bytecode of this shader, borrowed when possible.
    ///
    /// # Errors
    /// Implementations may fail, e.g. when the code has to be produced on demand.
    fn spirv(&self) -> Result<std::borrow::Cow<'_, [u8]>, failure::Error>;

    /// Name of the entry-point function inside the SPIR-V module.
    fn entry(&self) -> &str;

    /// Pipeline stage this shader targets.
    fn stage(&self) -> ShaderStageFlags;

    /// Creates a backend shader module from the bytecode returned by [`Shader::spirv`].
    ///
    /// # Safety
    /// Forwards to gfx-hal's `unsafe` `create_shader_module`; the caller inherits its
    /// contract. NOTE(review): presumably the returned module must be destroyed on the same
    /// device before that device is dropped — confirm against the gfx-hal docs.
    ///
    /// # Errors
    /// Propagates failures from [`Shader::spirv`] and from module creation.
    unsafe fn module<B>(
        &self,
        factory: &rendy_factory::Factory<B>,
    ) -> Result<B::ShaderModule, failure::Error>
    where
        B: Backend,
    {
        let code = self.spirv()?;
        let module = gfx_hal::Device::create_shader_module(factory.device().raw(), &code)?;
        Ok(module)
    }
}
/// A shader whose SPIR-V bytecode is already held in memory.
///
/// Build it with [`SpirvShader::new`]; its bytecode, stage and entry point are exposed
/// through the [`Shader`] trait.
// Field order is significant: it defines the derived `PartialOrd`/`Ord` comparison order
// and the serialized field order when the `serde` feature is enabled.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct SpirvShader {
    /// Raw SPIR-V bytecode (serialized through `serde_bytes` when `serde` is enabled).
    #[cfg_attr(feature = "serde", serde(with = "serde_bytes"))]
    spirv: Vec<u8>,
    /// Pipeline stage this shader targets.
    stage: ShaderStageFlags,
    /// Name of the entry-point function inside the module.
    entry: String,
}
impl SpirvShader {
pub fn new(spirv: Vec<u8>, stage: ShaderStageFlags, entrypoint: &str) -> Self {
assert!(!spirv.is_empty());
assert_eq!(spirv.len() % 4, 0);
Self {
spirv,
stage,
entry: entrypoint.to_string(),
}
}
}
impl Shader for SpirvShader {
    /// Borrows the stored bytecode; this implementation never fails.
    fn spirv(&self) -> Result<std::borrow::Cow<'_, [u8]>, failure::Error> {
        let bytes: &[u8] = self.spirv.as_slice();
        Ok(std::borrow::Cow::from(bytes))
    }

    /// Returns the entry-point name given to [`SpirvShader::new`].
    fn entry(&self) -> &str {
        self.entry.as_str()
    }

    /// Returns the stage given to [`SpirvShader::new`].
    fn stage(&self) -> ShaderStageFlags {
        let stage = self.stage;
        stage
    }
}
/// A set of shader stages, each with its SPIR-V and (once compiled) a backend module.
///
/// Compiled modules are not released automatically: call [`ShaderSet::dispose`] before
/// dropping the set, because each stored [`ShaderStorage`] panics in `Drop` while it
/// still owns a module.
#[derive(derivative::Derivative, Debug)]
// `bound = ""` lets `Default` be derived without requiring `B: Default`.
#[derivative(Default(bound = ""))]
pub struct ShaderSet<B: Backend> {
    // One storage per stage flag (VERTEX, FRAGMENT, COMPUTE, ...).
    shaders: HashMap<ShaderStageFlags, ShaderStorage<B>>,
}
impl<B: Backend> ShaderSet<B> {
pub fn load(
&mut self,
factory: &rendy_factory::Factory<B>,
) -> Result<&mut Self, failure::Error> {
for (_, v) in self.shaders.iter_mut() {
unsafe { v.compile(factory)? }
}
Ok(self)
}
pub fn raw<'a>(&'a self) -> Result<(gfx_hal::pso::GraphicsShaderSet<'a, B>), failure::Error> {
Ok(gfx_hal::pso::GraphicsShaderSet {
vertex: self
.shaders
.get(&ShaderStageFlags::VERTEX)
.unwrap()
.get_entry_point()?
.unwrap(),
fragment: match self.shaders.get(&ShaderStageFlags::FRAGMENT) {
Some(fragment) => fragment.get_entry_point()?,
None => None,
},
domain: match self.shaders.get(&ShaderStageFlags::DOMAIN) {
Some(domain) => domain.get_entry_point()?,
None => None,
},
hull: match self.shaders.get(&ShaderStageFlags::HULL) {
Some(hull) => hull.get_entry_point()?,
None => None,
},
geometry: match self.shaders.get(&ShaderStageFlags::GEOMETRY) {
Some(geometry) => geometry.get_entry_point()?,
None => None,
},
})
}
pub fn dispose(&mut self, factory: &rendy_factory::Factory<B>) {
for (_, shader) in self.shaders.iter_mut() {
shader.dispose(factory);
}
}
}
/// Optional specialization constants for each shader stage.
///
/// Passed to [`ShaderSetBuilder::build`]; a `None` entry leaves that stage unspecialized.
#[derive(Debug, Default, Clone)]
// NOTE(review): the reason for silencing this lint is not recorded; confirm whether
// `Specialization<'static>` is `Copy` in the gfx-hal version in use.
#[allow(missing_copy_implementations)]
pub struct SpecConstantSet {
    /// Specialization for the vertex stage.
    pub vertex: Option<gfx_hal::pso::Specialization<'static>>,
    /// Specialization for the fragment stage.
    pub fragment: Option<gfx_hal::pso::Specialization<'static>>,
    /// Specialization for the geometry stage.
    pub geometry: Option<gfx_hal::pso::Specialization<'static>>,
    /// Specialization for the hull (tessellation control) stage.
    pub hull: Option<gfx_hal::pso::Specialization<'static>>,
    /// Specialization for the domain (tessellation evaluation) stage.
    pub domain: Option<gfx_hal::pso::Specialization<'static>>,
    /// Specialization for the compute stage.
    pub compute: Option<gfx_hal::pso::Specialization<'static>>,
}
/// Collects per-stage SPIR-V bytecode and entry-point names, to be compiled into a
/// [`ShaderSet`] by [`ShaderSetBuilder::build`].
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ShaderSetBuilder {
    // Each stage is stored as `(spirv bytecode, entry-point name)`.
    vertex: Option<(Vec<u8>, String)>,
    fragment: Option<(Vec<u8>, String)>,
    geometry: Option<(Vec<u8>, String)>,
    hull: Option<(Vec<u8>, String)>,
    domain: Option<(Vec<u8>, String)>,
    compute: Option<(Vec<u8>, String)>,
}
impl ShaderSetBuilder {
pub fn build<B: Backend>(
&self,
factory: &rendy_factory::Factory<B>,
spec_constants: SpecConstantSet,
) -> Result<ShaderSet<B>, failure::Error> {
let mut set = ShaderSet::<B>::default();
if self.vertex.is_none() && self.compute.is_none() {
failure::bail!("A vertex or compute shader must be provided");
}
let create_storage = move |stage,
shader: (Vec<u8>, String, Option<gfx_hal::pso::Specialization<'static>>),
factory|
-> Result<ShaderStorage<B>, failure::Error> {
let mut storage = ShaderStorage {
stage: stage,
spirv: shader.0,
module: None,
entrypoint: shader.1.clone(),
specialization: shader.2,
};
unsafe {
storage.compile(factory)?;
}
Ok(storage)
};
if let Some(shader) = self.vertex.clone() {
set.shaders.insert(
ShaderStageFlags::VERTEX,
create_storage(ShaderStageFlags::VERTEX, (shader.0, shader.1, spec_constants.vertex), factory)?,
);
}
if let Some(shader) = self.fragment.clone() {
set.shaders.insert(
ShaderStageFlags::FRAGMENT,
create_storage(ShaderStageFlags::FRAGMENT, (shader.0, shader.1, spec_constants.fragment), factory)?,
);
}
if let Some(shader) = self.compute.clone() {
set.shaders.insert(
ShaderStageFlags::COMPUTE,
create_storage(ShaderStageFlags::COMPUTE, (shader.0, shader.1, spec_constants.compute), factory)?,
);
}
if let Some(shader) = self.domain.clone() {
set.shaders.insert(
ShaderStageFlags::DOMAIN,
create_storage(ShaderStageFlags::DOMAIN, (shader.0, shader.1, spec_constants.domain), factory)?,
);
}
if let Some(shader) = self.hull.clone() {
set.shaders.insert(
ShaderStageFlags::HULL,
create_storage(ShaderStageFlags::HULL, (shader.0, shader.1, spec_constants.hull), factory)?,
);
}
if let Some(shader) = self.geometry.clone() {
set.shaders.insert(
ShaderStageFlags::GEOMETRY,
create_storage(ShaderStageFlags::GEOMETRY, (shader.0, shader.1, spec_constants.geometry), factory)?,
);
}
Ok(set)
}
#[inline(always)]
pub fn with_vertex<S: Shader>(mut self, shader: &S) -> Result<Self, failure::Error> {
let data = shader.spirv()?;
self.vertex = Some((data.to_vec(), shader.entry().to_string()));
Ok(self)
}
#[inline(always)]
pub fn with_fragment<S: Shader>(mut self, shader: &S) -> Result<Self, failure::Error> {
let data = shader.spirv()?;
self.fragment = Some((data.to_vec(), shader.entry().to_string()));
Ok(self)
}
#[inline(always)]
pub fn with_geometry<S: Shader>(mut self, shader: &S) -> Result<Self, failure::Error> {
let data = shader.spirv()?;
self.geometry = Some((data.to_vec(), shader.entry().to_string()));
Ok(self)
}
#[inline(always)]
pub fn with_hull<S: Shader>(mut self, shader: &S) -> Result<Self, failure::Error> {
let data = shader.spirv()?;
self.hull = Some((data.to_vec(), shader.entry().to_string()));
Ok(self)
}
#[inline(always)]
pub fn with_domain<S: Shader>(mut self, shader: &S) -> Result<Self, failure::Error> {
let data = shader.spirv()?;
self.domain = Some((data.to_vec(), shader.entry().to_string()));
Ok(self)
}
#[inline(always)]
pub fn with_compute<S: Shader>(mut self, shader: &S) -> Result<Self, failure::Error> {
let data = shader.spirv()?;
self.compute = Some((data.to_vec(), shader.entry().to_string()));
Ok(self)
}
#[cfg(feature = "spirv-reflection")]
pub fn reflect(&self) -> Result<SpirvReflection, failure::Error> {
if self.vertex.is_none() && self.compute.is_none() {
failure::bail!("A vertex or compute shader must be provided");
}
let mut reflections = Vec::new();
if let Some(vertex) = self.vertex.as_ref() {
reflections.push(SpirvReflection::reflect(&vertex.0, None)?);
}
if let Some(fragment) = self.fragment.as_ref() {
reflections.push(SpirvReflection::reflect(&fragment.0, None)?);
}
if let Some(hull) = self.hull.as_ref() {
reflections.push(SpirvReflection::reflect(&hull.0, None)?);
}
if let Some(domain) = self.domain.as_ref() {
reflections.push(SpirvReflection::reflect(&domain.0, None)?);
}
if let Some(compute) = self.compute.as_ref() {
reflections.push(SpirvReflection::reflect(&compute.0, None)?);
}
if let Some(geometry) = self.geometry.as_ref() {
reflections.push(SpirvReflection::reflect(&geometry.0, None)?);
}
reflect::merge(&reflections)?.compile_cache()
}
}
/// One shader stage: its SPIR-V, entry point, optional specialization and, once
/// compiled, the backend shader module.
///
/// The module must be released with `dispose` before this value is dropped; the `Drop`
/// impl panics if a module is still present.
#[derive(Debug)]
pub struct ShaderStorage<B: Backend> {
    // Stage flag this storage was created for.
    stage: ShaderStageFlags,
    // SPIR-V bytecode used by `compile`.
    spirv: Vec<u8>,
    // `Some` after `compile`, `None` before compiling and after `dispose`.
    module: Option<B::ShaderModule>,
    // Entry-point function name.
    entrypoint: String,
    // Specialization constants; `get_entry_point` falls back to the default when `None`.
    specialization: Option<gfx_hal::pso::Specialization<'static>>,
}
impl<B: Backend> ShaderStorage<B> {
    /// Builds a gfx-hal entry point that references this stage's compiled module.
    ///
    /// When no specialization was supplied, `Specialization::default()` is used.
    ///
    /// # Errors
    /// Returns an error if the module has not been compiled yet or has already been
    /// disposed.
    pub fn get_entry_point<'a>(
        &'a self,
    ) -> Result<Option<gfx_hal::pso::EntryPoint<'a, B>>, failure::Error> {
        let module = self.module.as_ref().ok_or_else(|| {
            failure::format_err!(
                "Shader module for stage {:?} is not compiled (never compiled or already disposed)",
                self.stage
            )
        })?;
        Ok(Some(gfx_hal::pso::EntryPoint {
            entry: &self.entrypoint,
            module,
            specialization: self.specialization.clone().unwrap_or_default(),
        }))
    }

    /// Creates the backend shader module from the stored SPIR-V, unless one already exists.
    ///
    /// # Safety
    /// Forwards to gfx-hal's `unsafe` `create_shader_module`. NOTE(review): the exact
    /// caller contract is not documented here; confirm against the gfx-hal docs.
    ///
    /// # Errors
    /// Returns the error reported by module creation.
    pub unsafe fn compile(
        &mut self,
        factory: &rendy_factory::Factory<B>,
    ) -> Result<(), failure::Error> {
        if self.module.is_some() {
            // Overwriting would leak the existing module (e.g. `ShaderSet::load` after
            // `build`). The SPIR-V never changes, so the existing module is equivalent.
            return Ok(());
        }
        let module = gfx_hal::Device::create_shader_module(factory.device().raw(), &self.spirv)?;
        self.module = Some(module);
        Ok(())
    }

    // Destroys the compiled module, if any, leaving `module` as `None`.
    fn dispose(&mut self, factory: &rendy_factory::Factory<B>) {
        use gfx_hal::device::Device;
        if let Some(module) = self.module.take() {
            // NOTE(review): relies on `module` having been created by this factory's device
            // (as `compile` does) and no longer being needed by in-flight work.
            unsafe { factory.destroy_shader_module(module) };
        }
    }
}
/// Enforces that the backend module was released with `dispose` before drop.
impl<B: Backend> Drop for ShaderStorage<B> {
    fn drop(&mut self) {
        // Panicking while the thread is already unwinding aborts the whole process. Only
        // enforce the dispose-before-drop contract when this drop is not part of an unwind.
        if self.module.is_some() && !std::thread::panicking() {
            panic!("This shader storage class needs to be manually dropped with dispose() first");
        }
    }
}