Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ import sttp.tapir.codegen.openapi.models.GenerationDirectives.{
securityPrefixKey
}
import sttp.tapir.codegen.util.ErrUtils.bail
import sttp.tapir.codegen.util.{JavaEscape, Location}
import sttp.tapir.codegen.util.{JavaEscape, Location, NameHelpers}

case class EndpointTypes(security: Seq[String], in: Seq[String], err: Seq[String], out: Seq[String]) {
private def toType(types: Seq[String]) = types match {
Expand Down Expand Up @@ -282,7 +282,10 @@ class EndpointGenerator {

val maybeTargetFileName = if (useHeadTagForObjectNames) m.tags.flatMap(_.headOption) else None
val queryOrPathParamRefs = m.resolvedParameters
.collect { case queryParam: OpenapiParameter if queryParam.in == "query" || queryParam.in == "path" => queryParam.schema }
.collect {
case queryParam: OpenapiParameter if queryParam.in == "query" || queryParam.in == "path" || queryParam.in == "header" =>
queryParam.schema
}
.collect {
case ref: OpenapiSchemaRef if ref.isSchema => ref.stripped
case OpenapiSchemaArray(ref: OpenapiSchemaRef, _, _, _) if ref.isSchema => ref.stripped
Expand Down Expand Up @@ -611,7 +614,7 @@ class EndpointGenerator {
atts
.map { case (k, v) =>
val camelCaseK = strippedToCamelCase(k)
val uncapitalisedName = camelCaseK.head.toLower + camelCaseK.tail
val uncapitalisedName = camelCaseK.head.toLower +: camelCaseK.tail
s""".attribute[${camelCaseK.capitalize}Extension](${uncapitalisedName}ExtensionKey, ${SpecificationExtensionRenderer.renderValue(
v
)})"""
Expand Down Expand Up @@ -758,17 +761,13 @@ class EndpointGenerator {
)
}
}
def mappedGroup(group: Seq[OpenapiResponseDef], isErrorPosition: Boolean): (Option[String], Option[String], Option[String]) =
def mappedGroup(group: Seq[OpenapiResponseDef], isErrorPosition: Boolean): MappedOutGroup =
group match {
case Nil => (None, None, None)
case Nil => MappedOutGroup(None, None, None)
case resp +: Nil =>
val (outHeaderDefns, outHeaderInlineEnums, outHeaderTypes) = resp.headers
// according to api spec, content-type header should be ignored - cf https://swagger.io/specification/#response-object
.filterNot(_._1.toLowerCase == "content-type")
.map { case (name, defn) =>
genParamDefn(endpointName, targetScala3, jsonSerdeLib, defn.resolved(name, doc).param, doc, generateValidators)
}
.unzip3
val (outHeaderDefns, outHeaderInlineEnums, outHeaderTypes) = resp.getHeaders.map { case (name, defn) =>
genParamDefn(endpointName, targetScala3, jsonSerdeLib, defn.resolved(name, doc).param, doc, generateValidators)
}.unzip3
val hs = outHeaderDefns.map(d => s".and($d)").mkString
def ht(wrap: Boolean = true) =
if (outHeaderTypes.isEmpty) None
Expand All @@ -782,7 +781,7 @@ class EndpointGenerator {
resp.content match {
case Nil =>
val d = s""".description("${JavaEscape.escapeString(resp.description)}")"""
(
MappedOutGroup(
resp.code match {
case "200" | "default" if outHeaderDefns.isEmpty => None
case "200" => Some(s"statusCode(sttp.model.StatusCode(200))$d$hs")
Expand All @@ -793,19 +792,20 @@ class EndpointGenerator {
ht(),
inlineHeaderEnumDefns
)
case s =>
case _ =>
val (decl, maybeBodyType, inlineDefn) = bodyFmt(resp, isErrorPosition)
val tpe =
if (outHeaderTypes.isEmpty) maybeBodyType
else if (maybeBodyType.isEmpty) ht()
else maybeBodyType.map(t => s"($t, ${ht(false).get})")
val tpeIsBin = maybeBodyType.exists(t => t.contains("BinaryStream") || t.contains("fs2.Stream"))
(
MappedOutGroup(
Some(resp.code match {
case "200" | "default" => s"$decl$hs"
case okStatus(s) if tpeIsBin => s"$decl.toEndpointIO$hs.and(statusCode(sttp.model.StatusCode($s)))"
case okStatus(s) => s"$decl$hs.and(statusCode(sttp.model.StatusCode($s)))"
case errorStatus(s) => s"$decl$hs.and(statusCode(sttp.model.StatusCode($s)))"
case "200" | "default" if !tpeIsBin || hs.isEmpty => s"$decl$hs"
case "200" | "default" => s"$decl.toEndpointIO$hs"
case okStatus(s) if tpeIsBin => s"$decl.toEndpointIO$hs.and(statusCode(sttp.model.StatusCode($s)))"
case okStatus(s) => s"$decl$hs.and(statusCode(sttp.model.StatusCode($s)))"
case errorStatus(s) => s"$decl$hs.and(statusCode(sttp.model.StatusCode($s)))"
}),
tpe,
inlineDefn.map(_ ++ inlineHeaderEnumDefns.getOrElse("")).orElse(inlineHeaderEnumDefns)
Expand All @@ -815,9 +815,9 @@ class EndpointGenerator {
if (many.map(_.code).distinct.size != many.size) bail("Cannot construct schema for multiple responses with same status code")
val contentCanBeEmpty = many.exists(_.content.isEmpty)
val allBodiesAreEmpty = many.forall(_.content.isEmpty)
val allResponsesAreEmpty = allBodiesAreEmpty && many.forall(_.headers.isEmpty)
val allResponsesAreEmpty = allBodiesAreEmpty && many.forall(_.getHeaders.isEmpty)
val (noHeaders, hs, outHeaderDefns, matchHeaders, headerTypes, headerTopType) =
headerDefns(targetScala3, jsonSerdeLib, doc, generateValidators)(endpointName, many, isErrorPosition)
headerDefns(targetScala3, jsonSerdeLib, doc, generateValidators)(endpointName, many)
val (oneOfs, types, inlineDefns) = many.map { m =>
val (decl, maybeBodyType, inlineDefn1) = bodyFmt(m, isErrorPosition, optional = contentCanBeEmpty)
val code = if (m.code == "default") "400" else m.code
Expand Down Expand Up @@ -927,40 +927,38 @@ class EndpointGenerator {
}
}
val oneOfType = if (noHeaders) commmonType else if (allBodiesAreEmpty) headerTopType else s"($commmonType, $headerTopType)"
(
MappedOutGroup(
Some(s"oneOf[$oneOfType](${oneOfs.mkString("\n ", ",\n ", "")})"),
Some(oneOfType),
(inlineDefns ++ outHeaderDefns).foldLeft(Option.empty[String])(combine(_, _))
)
}

val (outDecls, outTypes, inlineOutDefns) = mappedGroup(outs, false)
val MappedOutGroup(outDecls, outTypes, inlineOutDefns) = mappedGroup(outs, false)
val mappedOuts = outDecls.map(s => s".out($s)")
val (errDecls, errTypes, inlineErrDefns) = mappedGroup(errorOuts, true)
val MappedOutGroup(errDecls, errTypes, inlineErrDefns) = mappedGroup(errorOuts, true)
val mappedErrorOuts = errDecls.map(s => s".errorOut($s)")

(Seq(mappedErrorOuts, mappedOuts).flatten.mkString("\n"), outTypes, errTypes, combine(inlineOutDefns, inlineErrDefns))
}

private def headerDefns(targetScala3: Boolean, jsonSerdeLib: JsonSerdeLib, doc: OpenapiDocument, generateValidators: Boolean)(
endpointName: String,
many: Seq[OpenapiResponseDef],
isErrorPosition: Boolean
many: Seq[OpenapiResponseDef]
)(implicit
location: Location
): (Boolean, OpenapiResponseDef => String, Seq[Option[String]], OpenapiResponseDef => String, OpenapiResponseDef => String, String) = {
val (paramNames, headerNamesAndTypes) = many.map { m =>
m.headers
.filterNot(_._1.toLowerCase == "content-type")
m.getHeaders
.map { case (name, defn) =>
val param = defn.resolved(name, doc).param
param.name -> genParamDefn(endpointName, targetScala3, jsonSerdeLib, param, doc, generateValidators)
NameHelpers.safeVariableName(param.name) ->
genParamDefn(endpointName, targetScala3, jsonSerdeLib, param, doc, generateValidators)
}
.toSeq
.sortBy(_._1)
.unzip
}.unzip
val posn = if (isErrorPosition) "error" else "output"
if (headerNamesAndTypes.forall(_.isEmpty)) (true, _ => "", Nil, _ => "", _ => "", "")
else if (headerNamesAndTypes.map(_.map { case (name, _, defn) => name -> defn }.toSet).distinct.size == 1) {
val commonResponseHeaders = headerNamesAndTypes.head
Expand All @@ -969,7 +967,7 @@ class EndpointGenerator {
val hs = (_: OpenapiResponseDef) => outHeaderDefns.zipWithIndex.map { case (d, 0) => d; case (d, _) => s".and($d)" }.mkString
val noHeaders = commonResponseHeaders.isEmpty

def ht(m: OpenapiResponseDef) =
val ht = (_: OpenapiResponseDef) =>
if (outHeaderTypes.isEmpty) bail("Should not try to construct header types if no headers are required")
else if (outHeaderTypes.size == 1) outHeaderTypes.head
else s"(${outHeaderTypes.mkString(", ")})"
Expand Down Expand Up @@ -1014,7 +1012,7 @@ class EndpointGenerator {
def ht(m: OpenapiResponseDef) = tpesByCode(m.code)
def getMapping(m: OpenapiResponseDef) = mappingsByCode(m.code)
def getMatch(m: OpenapiResponseDef) =
if (m.headers.forall(_._1.toLowerCase == "content-type")) s"$traitName${m.code}"
if (m.getHeaders.isEmpty) s"$traitName${m.code}"
else s"(_: $traitName${m.code})"
val enums = enumDefns.flatten.distinct.map(Some(_))
(false, getMapping, Some(headerTypeDefns) +: enums, getMatch, ht, traitName)
Expand Down Expand Up @@ -1204,3 +1202,4 @@ class EndpointGenerator {
}

case class MappedContentType(bodyImpl: String, bodyType: String, inlineDefns: Option[String] = None)
case class MappedOutGroup(decls: Option[String], types: Option[String], defns: Option[String])
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ object ServersGenerator {
val (withDefault, withoutDefault) = enumNames.partition(_._3.isDefined)
val enumParams = (withoutDefault ++ withDefault)
.map {
case (e, vs, d) if vs.nonEmpty => s"_$e: $e${d.map(v => s" = $e.default").getOrElse("")}"
case (e, vs, d) if vs.nonEmpty => s"_$e: $e${d.map(_ => s" = $e.default").getOrElse("")}"
case (e, _, d) if d.nonEmpty => s"_$e: String${d.map(_ => s" = ${safeVariableName(s"default${e.capitalize}")}").getOrElse("")}"
case (e, _, _) => s"_$e: String"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import sttp.tapir.codegen.openapi.models.OpenapiSchemaType.{
OpenapiSchemaSimpleType,
OpenapiSchemaString
}
import sttp.tapir.codegen.util.{DocUtils, JavaEscape}
import sttp.tapir.codegen.util.JavaEscape

import scala.annotation.tailrec

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import OpenapiSchemaType.{
OpenapiSchemaObject,
OpenapiSchemaRef,
OpenapiSchemaRefDecoder,
OpenapiSchemaSimpleType
OpenapiSchemaString
}
import io.circe.Json
import sttp.tapir.codegen.RootGenerator.strippedToCamelCase
Expand All @@ -20,7 +20,6 @@ import sttp.tapir.codegen.openapi.models.GenerationDirectives.{
forceRespStreaming,
forceStreaming
}
import sttp.tapir.codegen.util.MapUtils

import scala.collection.mutable
// https://swagger.io/specification/
Expand Down Expand Up @@ -204,9 +203,28 @@ object OpenapiModels {
code: String,
description: String,
content: Seq[OpenapiResponseContent],
headers: Map[String, OpenapiHeader] = Map.empty
private val headers: Map[String, OpenapiHeader] = Map.empty
) extends OpenapiResponse {
def resolve(doc: OpenapiDocument): OpenapiResponseDef = this
private def maybeContentTypeHeader: Option[(String, OpenapiHeader)] = if (content.forall(!_.contentType.contains("*"))) None
else {
def generalRegex = "([^*]+|[*])/([^*]+|[*])"
val validatingRegex = content.map(_.contentType) match {
Copy link
Copy Markdown
Contributor Author

@hughsimpson hughsimpson Mar 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe having the regex at all is overkill here... For an end endpoint that declared both e.g. image/* and application/* responses on the same status code it wouldn't ensure that you hadn't mixed the two responses up anyway. Still, it seemed like a good idea at the time - that case is fairly niche, and at least it prevents the more egregious possible errors.

case s if s.contains("*/*") => generalRegex
case Seq(oneType) => oneType.replace("*", "([^*]+|[*])")
case s => s.map(_.replace("*", "([^*]+|[*])")).map(t => s"($t)").mkString("|")
}
Some(
"Content-Type" -> OpenapiHeaderDef(
OpenapiParameter("Content-Type", "header", Some(true), None, OpenapiSchemaString(false, Some(validatingRegex), None, None))
)
)
}
def getHeaders: Map[String, OpenapiHeader] =
// according to api spec, content-type header should be ignored - cf https://swagger.io/specification/#response-object
headers.filterNot(_._1.toLowerCase == "content-type") ++
// but add one explicitly if content contains a wildcard
maybeContentTypeHeader
}
case class OpenapiResponseRef(
code: String,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package sttp.tapir.codegen.openapi.models

import sttp.tapir.codegen.openapi.models.OpenapiModels.{OpenapiHeader, OpenapiRequestBodyContent, OpenapiResponseContent}
import sttp.tapir.codegen.openapi.models.OpenapiSchemaType.{OpenapiSchemaRef, OpenapiSchemaSimpleType}

case class OpenapiResponseDefn(
description: String,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
package sttp.tapir.codegen.openapi.models

import io.circe.generic.semiauto.deriveDecoder
import sttp.tapir.codegen.openapi.models.OpenapiSchemaType.OpenapiSchemaEnum

case class OpenapiServerEnum(`enum`: Seq[String] = Nil, default: Option[String] = None)
object OpenapiServerEnum {
import io.circe._
Expand Down
Loading
Loading