use proc_macro2::TokenStream;
use quote::quote;
use syn::{Ident, Path, Result, Type};
use petgraph::{visit::EdgeRef, Direction};
use super::*;
/// Generates all subsystem-related types and trait implementations for the orchestra
/// described by `info`.
///
/// The returned stream contains, in this order:
/// 1. one `From<…OutgoingMessages>` impl for the all-messages wrapper per node of the
///    connection graph,
/// 2. the subsystem context struct, the `AssociateOutgoing` trait and the sender type,
/// 3. per subsystem: the helper traits, the `AssociateOutgoing` impls and the outgoing
///    wrapper enum,
/// 4. a `messages` module holding whatever `gen_dummy_message_ty` produces per subsystem,
/// 5. the subsystem context trait implementation for the `()` message type.
///
/// # Errors
///
/// Propagates the error of `to_variant` / `impl_wrapper_enum` when a message path has no
/// segments.
pub(crate) fn impl_subsystem_types_all(info: &OrchestraInfo) -> Result<TokenStream> {
    let orchestra_name = &info.orchestra_name;
    let span = orchestra_name.span();
    let all_messages_wrapper = &info.message_wrapper;
    let support_crate = info.support_crate_name();
    let signal_ty = &info.extern_signal_ty;
    let error_ty = &info.extern_error_ty;

    // Builds the identifier `<base><suffix>` carrying the orchestra name's span.
    let suffixed = |base: &str, suffix: &str| Ident::new(&format!("{}{}", base, suffix), span);

    let mut ts = TokenStream::new();

    let cg = graph::ConnectionGraph::construct(info.subsystems());
    let graph = &cg.graph;

    // Every outgoing edge of a node maps one message variant of that node's outgoing
    // wrapper onto the all-messages variant named after the edge's target node.
    for node_index in graph.node_indices() {
        let outgoing_wrapper = suffixed(&graph[node_index].to_string(), "OutgoingMessages");

        // Two parallel vectors, because `quote!` repetitions iterate them in lock-step.
        let (outgoing_variant, subsystem_generic): (Vec<Ident>, Vec<Ident>) = graph
            .edges_directed(node_index, Direction::Outgoing)
            .map(|edge| {
                to_variant(edge.weight(), span)
                    .map(|variant| (variant, graph[edge.target()].clone()))
            })
            .collect::<Result<Vec<(Ident, Ident)>>>()?
            .into_iter()
            .unzip();

        ts.extend(quote! {
            impl ::std::convert::From< #outgoing_wrapper > for #all_messages_wrapper {
                fn from(message: #outgoing_wrapper) -> Self {
                    match message {
                    #(
                        #outgoing_wrapper :: #outgoing_variant ( msg ) => #all_messages_wrapper :: #subsystem_generic ( msg ),
                    )*
                        #outgoing_wrapper :: Empty => #all_messages_wrapper :: Empty,
                        // Variants without an outgoing edge end up here.
                        #[allow(unreachable_patterns)]
                        unused_msg => {
                            #support_crate :: tracing :: warn!("Nothing consumes {:?}", unused_msg);
                            #all_messages_wrapper :: Empty
                        }
                    }
                }
            }
        });
    }

    // Best effort dump of the connection graph; a failure is only reported, never fatal.
    #[cfg(feature = "graph")]
    {
        let file_name = orchestra_name.to_string().to_lowercase() + "-subsystem-messaging.dot";
        let path = std::path::PathBuf::from(env!("OUT_DIR")).join(file_name);
        let written = std::fs::OpenOptions::new()
            .truncate(true)
            .create(true)
            .write(true)
            .open(&path)
            .and_then(|mut f| cg.graphviz(&mut f));
        match written {
            Err(e) => eprintln!("Failed to write dot graph to {}: {:?}", path.display(), e),
            Ok(_) => println!("Wrote dot graph to {}", path.display()),
        }
    }

    let subsystem_sender_name = suffixed(&orchestra_name.to_string(), "Sender");
    let subsystem_ctx_name = suffixed(&orchestra_name.to_string(), "SubsystemContext");

    ts.extend(impl_subsystem_context(info, &subsystem_sender_name, &subsystem_ctx_name));
    ts.extend(impl_associate_outgoing_messages_trait(all_messages_wrapper));
    ts.extend(impl_subsystem_sender(
        support_crate,
        info.subsystems()
            .iter()
            .map(|ssf| suffixed(&ssf.generic.to_string(), "OutgoingMessages")),
        all_messages_wrapper,
        &subsystem_sender_name,
    ));

    for ssf in info.subsystems() {
        let subsystem_name = ssf.generic.to_string();
        let outgoing_wrapper = suffixed(&subsystem_name, "OutgoingMessages");
        let message_to_consume = ssf.message_to_consume();
        let subsystem_ctx_trait = suffixed(&subsystem_name, "ContextTrait");
        let subsystem_sender_trait = suffixed(&subsystem_name, "SenderTrait");

        ts.extend(impl_per_subsystem_helper_traits(
            info,
            &subsystem_ctx_name,
            &subsystem_ctx_trait,
            &subsystem_sender_name,
            &subsystem_sender_trait,
            &message_to_consume,
            &ssf.messages_to_send,
            &outgoing_wrapper,
        ));
        ts.extend(impl_associate_outgoing_messages(&message_to_consume, &outgoing_wrapper));
        ts.extend(impl_wrapper_enum(&outgoing_wrapper, ssf.messages_to_send.as_slice())?);
    }

    // Collect the per-subsystem dummy message types into one dedicated module.
    let mut messages = TokenStream::new();
    for ssf in info.subsystems() {
        messages.extend(ssf.gen_dummy_message_ty());
    }
    let comment = "The exclusive home of all generated dummy messages (if any at all)";
    ts.extend(quote! {
        #[doc = #comment]
        pub mod messages {
            #messages
        }
    });

    // The context is additionally implemented for `()`, with no outgoing message types.
    let empty_tuple: Type = parse_quote! { () };
    ts.extend(impl_subsystem_context_trait_for(
        empty_tuple.clone(),
        &[],
        empty_tuple,
        all_messages_wrapper,
        &subsystem_ctx_name,
        &subsystem_sender_name,
        support_crate,
        signal_ty,
        error_ty,
    ));

    Ok(ts)
}
fn to_variant(path: &Path, span: Span) -> Result<Ident> {
let ident = path
.segments
.last()
.ok_or_else(|| syn::Error::new(span, "Path is empty, but it must end with an identifier"))
.map(|segment| segment.ident.clone())?;
Ok(ident)
}
/// Maps every path in `message_types` to its variant identifier via `to_variant`.
///
/// # Errors
///
/// Stops at, and returns, the first error `to_variant` reports.
fn to_variants(message_types: &[Path], span: Span) -> Result<Vec<Ident>> {
    message_types
        .iter()
        .map(|message_ty| to_variant(message_ty, span))
        .collect()
}
/// Generates the enum `wrapper` with one tuple variant per entry of `message_types` plus a
/// unit variant `Empty`, together with `From` conversions from every message type and
/// from `()` (which maps to `Empty`).
///
/// Variant names are the last path segment of the respective message type.
///
/// # Errors
///
/// Fails when one of the `message_types` paths has no segments.
pub(crate) fn impl_wrapper_enum(wrapper: &Ident, message_types: &[Path]) -> Result<TokenStream> {
    to_variants(message_types, wrapper.span()).map(|variants| {
        quote! {
            #[allow(missing_docs)]
            #[derive(Debug)]
            pub enum #wrapper {
                #(
                    #variants ( #message_types ),
                )*
                Empty,
            }

            #(
                impl ::std::convert::From< #message_types > for #wrapper {
                    fn from(message: #message_types) -> Self {
                        #wrapper :: #variants ( message )
                    }
                }
            )*

            impl ::std::convert::From< () > for #wrapper {
                fn from(_message: ()) -> Self {
                    #wrapper :: Empty
                }
            }
        }
    })
}
pub(crate) fn impl_subsystem_sender(
support_crate: &Path,
outgoing_wrappers: impl IntoIterator<Item = Ident>,
all_messages_wrapper: &Ident,
subsystem_sender_name: &Ident,
) -> TokenStream {
let mut ts = quote! {
#[derive(Debug)]
pub struct #subsystem_sender_name < OutgoingWrapper > {
channels: ChannelsOut,
signals_received: SignalsReceived,
_phantom: ::core::marker::PhantomData< OutgoingWrapper >,
}
impl<OutgoingWrapper> std::clone::Clone for #subsystem_sender_name < OutgoingWrapper > {
fn clone(&self) -> Self {
Self {
channels: self.channels.clone(),
signals_received: self.signals_received.clone(),
_phantom: ::core::marker::PhantomData::default(),
}
}
}
};
let wrapped = |outgoing_wrapper: &TokenStream| {
quote! {
#[#support_crate ::async_trait]
impl<OutgoingMessage> SubsystemSender< OutgoingMessage > for #subsystem_sender_name < #outgoing_wrapper >
where
OutgoingMessage: Send + 'static,
#outgoing_wrapper: ::std::convert::From<OutgoingMessage> + Send,
#all_messages_wrapper: ::std::convert::From< #outgoing_wrapper > + Send,
{
async fn send_message(&mut self, msg: OutgoingMessage)
{
self.channels.send_and_log_error(
self.signals_received.load(),
<#all_messages_wrapper as ::std::convert::From<_>> ::from (
<#outgoing_wrapper as ::std::convert::From<_>> :: from ( msg )
)
).await;
}
async fn send_messages<I>(&mut self, msgs: I)
where
I: IntoIterator<Item=OutgoingMessage> + Send,
I::IntoIter: Iterator<Item=OutgoingMessage> + Send,
{
for msg in msgs {
self.send_message( msg ).await;
}
}
fn send_unbounded_message(&mut self, msg: OutgoingMessage)
{
self.channels.send_unbounded_and_log_error(
self.signals_received.load(),
<#all_messages_wrapper as ::std::convert::From<_>> ::from (
<#outgoing_wrapper as ::std::convert::From<_>> :: from ( msg )
)
);
}
}
}
};
for outgoing_wrapper in outgoing_wrappers {
ts.extend(wrapped("e! {
#outgoing_wrapper
}));
}
ts.extend(wrapped("e! {
()
}));
ts
}
/// Generates the `AssociateOutgoing` trait, whose `OutgoingMessages` associated type must
/// convert into `all_messages_wrapper`, along with its implementations for `()` and for
/// `all_messages_wrapper` itself (both associate the type with itself).
pub(crate) fn impl_associate_outgoing_messages_trait(all_messages_wrapper: &Ident) -> TokenStream {
    let trait_definition = quote! {
        pub trait AssociateOutgoing: ::std::fmt::Debug + Send {
            type OutgoingMessages: Into< #all_messages_wrapper > + ::std::fmt::Debug + Send;
        }
    };
    let impl_for_unit = quote! {
        impl AssociateOutgoing for () {
            type OutgoingMessages = ();
        }
    };
    let impl_for_all_messages = quote! {
        impl AssociateOutgoing for #all_messages_wrapper {
            type OutgoingMessages = #all_messages_wrapper ;
        }
    };

    quote! {
        #trait_definition
        #impl_for_unit
        #impl_for_all_messages
    }
}
/// Generates two `AssociateOutgoing` impls that both name `outgoing_wrapper` as their
/// `OutgoingMessages`: one for `outgoing_wrapper` itself and one for the `consumes` type.
pub(crate) fn impl_associate_outgoing_messages(
    consumes: &Path,
    outgoing_wrapper: &Ident,
) -> TokenStream {
    let impl_for_wrapper = quote! {
        impl AssociateOutgoing for #outgoing_wrapper {
            type OutgoingMessages = #outgoing_wrapper;
        }
    };
    let impl_for_consumed = quote! {
        impl AssociateOutgoing for #consumes {
            type OutgoingMessages = #outgoing_wrapper;
        }
    };

    quote! {
        #impl_for_wrapper
        #impl_for_consumed
    }
}
/// Generates the `SubsystemContext` implementation for `#subsystem_ctx_name < #consumes >`.
///
/// * `consumes` - becomes the `Message` associated type.
/// * `outgoing` - every entry adds a `From<..>` bound on `outgoing_wrapper`.
/// * `outgoing_wrapper` - becomes `OutgoingMessages` and parameterises the `Sender` type.
/// * `all_messages_wrapper` - must be constructible from `outgoing_wrapper` and `consumes`.
/// * `subsystem_ctx_name` / `subsystem_sender_name` - names of the context and sender types.
/// * `support_crate` - path prefix for the items referenced from the support crate.
/// * `signal` / `error_ty` - become the `Signal` and `Error` associated types.
pub(crate) fn impl_subsystem_context_trait_for(
consumes: Type,
outgoing: &[Type],
outgoing_wrapper: Type,
all_messages_wrapper: &Ident,
subsystem_ctx_name: &Ident,
subsystem_sender_name: &Ident,
support_crate: &Path,
signal: &Path,
error_ty: &Path,
) -> TokenStream {
// Predicates of the generated impl. `#( From< #outgoing > )+*` repeats once per entry
// of `outgoing`, with `+` as separator.
// NOTE(review): for an empty `outgoing` slice (this file passes `&[]` for the `()`
// context) this presumably leaves a predicate without any bound - confirm it is intended.
let where_clause = quote! {
#consumes: AssociateOutgoing + ::std::fmt::Debug + Send + 'static,
#all_messages_wrapper: From< #outgoing_wrapper >,
#all_messages_wrapper: From< #consumes >,
#outgoing_wrapper: #( From< #outgoing > )+*,
};
quote! {
#[#support_crate ::async_trait]
impl #support_crate ::SubsystemContext for #subsystem_ctx_name < #consumes >
where
#where_clause
{
type Message = #consumes;
type Signal = #signal;
type OutgoingMessages = #outgoing_wrapper;
type Sender = #subsystem_sender_name < #outgoing_wrapper >;
type Error = #error_ty;
// Polls `recv()` once: a ready value is returned as `Some` (its error, if any, is
// collapsed into `()`), a pending one as `None`.
async fn try_recv(&mut self) -> ::std::result::Result<Option<FromOrchestra< Self::Message, #signal>>, ()> {
match #support_crate ::poll!(self.recv()) {
#support_crate ::Poll::Ready(msg) => Ok(Some(msg.map_err(|_| ())?)),
#support_crate ::Poll::Pending => Ok(None),
}
}
async fn recv(&mut self) -> ::std::result::Result<FromOrchestra<Self::Message, #signal>, #error_ty> {
loop {
// A previously stashed message takes priority over reading new messages.
if let Some((needs_signals_received, msg)) = self.pending_incoming.take() {
if needs_signals_received <= self.signals_received.load() {
// Enough signals were counted in the meantime: deliver the stashed message.
return Ok( #support_crate ::FromOrchestra::Communication { msg });
} else {
// Still too early: put the message back and only wait for the next signal.
self.pending_incoming = Some((needs_signals_received, msg));
let signal = self.signals.next().await
.ok_or(#support_crate ::OrchestraError::Context(
"Signal channel is terminated and empty."
.to_owned()
))?;
self.signals_received.inc();
return Ok( #support_crate ::FromOrchestra::Signal(signal))
}
}
let mut await_message = self.messages.next().fuse();
let mut await_signal = self.signals.next().fuse();
// Snapshot the counter and borrow the stash up front, so the select arms below
// do not need to access `self`.
let signals_received = self.signals_received.load();
let pending_incoming = &mut self.pending_incoming;
// The signal arm is listed first in `select_biased!`.
let from_orchestra = #support_crate ::futures::select_biased! {
signal = await_signal => {
let signal = signal
.ok_or( #support_crate ::OrchestraError::Context(
"Signal channel is terminated and empty."
.to_owned()
))?;
#support_crate ::FromOrchestra::Signal(signal)
}
msg = await_message => {
let packet = msg
.ok_or( #support_crate ::OrchestraError::Context(
"Message channel is terminated and empty."
.to_owned()
))?;
if packet.signals_received > signals_received {
// The packet carries a higher signal count than this context has
// counted so far: stash it and go around the loop again.
*pending_incoming = Some((packet.signals_received, packet.message));
continue;
} else {
#support_crate ::FromOrchestra::Communication { msg: packet.message}
}
}
};
// Every signal handed out is counted.
if let #support_crate ::FromOrchestra::Signal(_) = from_orchestra {
self.signals_received.inc();
}
return Ok(from_orchestra);
}
}
fn sender(&mut self) -> &mut Self::Sender {
&mut self.to_subsystems
}
// Sends a `SpawnJob` request over `to_orchestra`; a failed send is reported as
// `OrchestraError::TaskSpawn(name)`.
fn spawn(&mut self, name: &'static str, s: Pin<Box<dyn Future<Output = ()> + Send>>)
-> ::std::result::Result<(), #error_ty>
{
self.to_orchestra.unbounded_send(#support_crate ::ToOrchestra::SpawnJob {
name,
subsystem: Some(self.name()),
s,
}).map_err(|_| #support_crate ::OrchestraError::TaskSpawn(name))?;
Ok(())
}
// Same as `spawn`, but sends a `SpawnBlockingJob` request instead.
fn spawn_blocking(&mut self, name: &'static str, s: Pin<Box<dyn Future<Output = ()> + Send>>)
-> ::std::result::Result<(), #error_ty>
{
self.to_orchestra.unbounded_send(#support_crate ::ToOrchestra::SpawnBlockingJob {
name,
subsystem: Some(self.name()),
s,
}).map_err(|_| #support_crate ::OrchestraError::TaskSpawn(name))?;
Ok(())
}
}
}
}
pub(crate) fn impl_per_subsystem_helper_traits(
info: &OrchestraInfo,
subsystem_ctx_name: &Ident,
subsystem_ctx_trait: &Ident,
subsystem_sender_name: &Ident,
subsystem_sender_trait: &Ident,
consumes: &Path,
outgoing: &[Path],
outgoing_wrapper: &Ident,
) -> TokenStream {
let all_messages_wrapper = &info.message_wrapper;
let signal_ty = &info.extern_signal_ty;
let error_ty = &info.extern_error_ty;
let support_crate = info.support_crate_name();
let mut ts = TokenStream::new();
let acc_sender_trait_bounds = quote! {
#support_crate ::SubsystemSender< #outgoing_wrapper >
#(
+ #support_crate ::SubsystemSender< #outgoing >
)*
+ #support_crate ::SubsystemSender< () >
+ Send
+ 'static
};
ts.extend(quote! {
pub trait #subsystem_sender_trait : #acc_sender_trait_bounds
{}
impl<T> #subsystem_sender_trait for T
where
T: #acc_sender_trait_bounds
{}
});
let where_clause = quote! {
#consumes: AssociateOutgoing + ::std::fmt::Debug + Send + 'static,
#all_messages_wrapper: From< #outgoing_wrapper >,
#all_messages_wrapper: From< #consumes >,
#all_messages_wrapper: From< () >,
#outgoing_wrapper: #( From< #outgoing > )+*,
#outgoing_wrapper: From< () >,
};
ts.extend(quote! {
pub trait #subsystem_ctx_trait : SubsystemContext <
Message = #consumes,
Signal = #signal_ty,
OutgoingMessages = #outgoing_wrapper,
Error = #error_ty,
>
where
#where_clause
<Self as SubsystemContext>::Sender:
#subsystem_sender_trait
+ #acc_sender_trait_bounds,
{
type Sender: #subsystem_sender_trait;
}
impl<T> #subsystem_ctx_trait for T
where
T: SubsystemContext <
Message = #consumes,
Signal = #signal_ty,
OutgoingMessages = #outgoing_wrapper,
Error = #error_ty,
>,
#where_clause
<T as SubsystemContext>::Sender:
#subsystem_sender_trait
+ #acc_sender_trait_bounds,
{
type Sender = <T as SubsystemContext>::Sender;
}
});
ts.extend(impl_subsystem_context_trait_for(
parse_quote! { #consumes },
&Vec::from_iter(outgoing.iter().map(|path| {
parse_quote! { #path }
})),
parse_quote! { #outgoing_wrapper },
all_messages_wrapper,
subsystem_ctx_name,
subsystem_sender_name,
support_crate,
signal_ty,
error_ty,
));
ts
}
/// Generates the subsystem context struct `subsystem_ctx_name<M>` together with its
/// inherent `new` constructor and `name` accessor.
///
/// The generated `new` wraps the passed `ChannelsOut` into a `subsystem_sender_name`
/// value that shares the context's `SignalsReceived` counter (via `clone()`), and starts
/// with no pending incoming message.
pub(crate) fn impl_subsystem_context(
    info: &OrchestraInfo,
    subsystem_sender_name: &Ident,
    subsystem_ctx_name: &Ident,
) -> TokenStream {
    let support_crate = info.support_crate_name();
    let signal_ty = &info.extern_signal_ty;

    quote! {
        #[derive(Debug)]
        #[allow(missing_docs)]
        pub struct #subsystem_ctx_name<M: AssociateOutgoing + Send + 'static> {
            signals: #support_crate ::metered::MeteredReceiver< #signal_ty >,
            messages: SubsystemIncomingMessages< M >,
            to_subsystems: #subsystem_sender_name < <M as AssociateOutgoing>::OutgoingMessages >,
            to_orchestra: #support_crate ::metered::UnboundedMeteredSender<
                #support_crate ::ToOrchestra
                >,
            signals_received: SignalsReceived,
            pending_incoming: Option<(usize, M)>,
            name: &'static str
        }

        impl<M> #subsystem_ctx_name <M>
        where
            M: AssociateOutgoing + Send + 'static,
        {
            fn new(
                signals: #support_crate ::metered::MeteredReceiver< #signal_ty >,
                messages: SubsystemIncomingMessages< M >,
                to_subsystems: ChannelsOut,
                to_orchestra: #support_crate ::metered::UnboundedMeteredSender<#support_crate:: ToOrchestra>,
                name: &'static str
            ) -> Self {
                let signals_received = SignalsReceived::default();
                #subsystem_ctx_name :: <M> {
                    signals,
                    messages,
                    to_subsystems: #subsystem_sender_name :: < <M as AssociateOutgoing>::OutgoingMessages > {
                        channels: to_subsystems,
                        signals_received: signals_received.clone(),
                        _phantom: ::core::marker::PhantomData::default(),
                    },
                    to_orchestra,
                    signals_received,
                    pending_incoming: None,
                    name
                }
            }

            fn name(&self) -> &'static str {
                self.name
            }
        }
    }
}