Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 93 additions & 12 deletions crates/go/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
use anyhow::Result;
use heck::{ToLowerCamelCase as _, ToSnakeCase as _, ToUpperCamelCase as _};
use std::borrow::Cow;
use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet, hash_map};
use std::fmt;
use std::fmt::Write as _;
use std::io::{self, Write as _};
use std::iter;
use std::mem;
use std::process::Command;
use std::str::FromStr;
use std::thread;
use wit_bindgen_core::abi::{
self, AbiVariant, Bindgen, Bitcast, FlatTypes, Instruction, LiftLower, WasmType,
};
Expand All @@ -27,9 +33,55 @@ const EXPORT_RETURN_AREA: &str = "exportReturnArea";
const SYNC_EXPORT_PINNER: &str = "syncExportPinner";
const PINNER: &str = "pinner";

#[derive(Default, Debug, Copy, Clone)]
pub enum Format {
#[default]
True,
False,
}

impl fmt::Display for Format {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"{}",
match self {
Self::True => "true",
Self::False => "false",
}
)
}
}

impl FromStr for Format {
type Err = String;

fn from_str(s: &str) -> Result<Format, String> {
match s {
"true" => Ok(Format::True),
"false" => Ok(Format::False),
_ => Err(format!("expected `true` or `false`; got `{s}`")),
}
}
}

#[derive(Default, Debug, Clone)]
#[cfg_attr(feature = "clap", derive(clap::Parser))]
pub struct Opts {
/// Whether or not `gofmt` should be used (if present) to format generated
/// code.
#[cfg_attr(
feature = "clap",
arg(
long,
default_value = "true",
default_missing_value = "true",
num_args = 0..=1,
require_equals = true,
)
)]
pub format: Format,

#[cfg_attr(feature = "clap", clap(flatten))]
pub async_: AsyncFilterSet,

Expand Down Expand Up @@ -717,8 +769,10 @@ impl WorldGenerator for Go {

files.push(
"wit_bindings.go",
format!(
r#"package main
&maybe_gofmt(
self.opts.format,
format!(
r#"package main

import (
"runtime"
Expand All @@ -734,8 +788,9 @@ var {SYNC_EXPORT_PINNER} = runtime.Pinner{{}}
// Unused, but present to make the compiler happy
func main() {{}}
"#
)
.as_bytes(),
)
.as_bytes(),
),
);
files.push("go.mod", b"module wit_component\n\ngo 1.25");
files.push(
Expand All @@ -750,16 +805,19 @@ func main() {{}}

files.push(
&format!("{prefix}{name}/wit_bindings.go"),
format!(
"package {prefix}{name}
&maybe_gofmt(
self.opts.format,
format!(
"package {prefix}{name}

import (
{imports}
)

{code}"
)
.as_bytes(),
)
.as_bytes(),
),
);
}
}
Expand Down Expand Up @@ -788,13 +846,16 @@ import (

files.push(
"wit_types/wit_tuples.go",
format!(
r#"package wit_types
&maybe_gofmt(
self.opts.format,
format!(
r#"package wit_types

{tuples}
"#
)
.as_bytes(),
)
.as_bytes(),
),
);
}

Expand Down Expand Up @@ -2948,3 +3009,23 @@ fn func_declaration(resolve: &Resolve, func: &Function) -> (String, bool) {
}
}
}

fn maybe_gofmt<'a>(format: Format, code: &'a [u8]) -> Cow<'a, [u8]> {
return thread::scope(|s| {
if let Format::True = format
&& let Ok((reader, mut writer)) = io::pipe()
{
s.spawn(move || {
_ = writer.write_all(&code);
});

if let Ok(output) = Command::new("gofmt").stdin(reader).output()
&& output.status.success()
{
return Cow::Owned(output.stdout);
}
}

Cow::Borrowed(code)
});
}
Loading