diff --git a/generators/go-v2/ast/src/ast/Func.ts b/generators/go-v2/ast/src/ast/Func.ts index aa328a55ed0..a85c9e69e43 100644 --- a/generators/go-v2/ast/src/ast/Func.ts +++ b/generators/go-v2/ast/src/ast/Func.ts @@ -17,10 +17,6 @@ export class Func extends AstNode { this.func = new Method({ name, parameters, return_, body, docs }); } - public get name(): string { - return this.func.name; - } - public get parameters(): Parameter[] { return this.func.parameters; } @@ -29,6 +25,10 @@ export class Func extends AstNode { return this.func.return_; } + public get name(): string | undefined { + return this.func.name; + } + public get body(): CodeBlock | undefined { return this.func.body; } diff --git a/generators/go-v2/ast/src/ast/FuncInvocation.ts b/generators/go-v2/ast/src/ast/FuncInvocation.ts index 7aea423dbd0..11f430bbeb2 100644 --- a/generators/go-v2/ast/src/ast/FuncInvocation.ts +++ b/generators/go-v2/ast/src/ast/FuncInvocation.ts @@ -9,22 +9,26 @@ export declare namespace FuncInvocation { func: GoTypeReference; /* The arguments passed to the method */ arguments_: AstNode[]; + /* Whether to write the invocation on multiple lines */ + multiline?: boolean; } } export class FuncInvocation extends AstNode { private func: GoTypeReference; private arguments_: AstNode[]; + private multiline: boolean | undefined; - constructor({ func, arguments_ }: FuncInvocation.Args) { + constructor({ func, arguments_, multiline = true }: FuncInvocation.Args) { super(); this.func = func; this.arguments_ = arguments_; + this.multiline = multiline; } public write(writer: Writer): void { writer.writeNode(this.func); - writeArguments({ writer, arguments_: this.arguments_ }); + writeArguments({ writer, arguments_: this.arguments_, multiline: this.multiline }); } } diff --git a/generators/go-v2/ast/src/ast/GoTypeReference.ts b/generators/go-v2/ast/src/ast/GoTypeReference.ts index 2330f02e841..6fd3acf1999 100644 --- a/generators/go-v2/ast/src/ast/GoTypeReference.ts +++ b/generators/go-v2/ast/src/ast/GoTypeReference.ts @@ -1,3 +1,4 @@ +import { Type } from "./Type"; import { AstNode } from "./core/AstNode"; import { Writer } from "./core/Writer"; @@ -7,17 +8,21 @@ export declare namespace GoTypeReference { name: string; /* The import path of the Go type */ importPath: string; + /* The generic type parameters, if any */ + generics?: Type[] | undefined; } } export class GoTypeReference extends AstNode { public readonly name: string; public readonly importPath: string; + public readonly generics: Type[] | undefined; - constructor({ name, importPath }: GoTypeReference.Args) { + constructor({ name, importPath, generics }: GoTypeReference.Args) { super(); this.name = name; this.importPath = importPath; + this.generics = generics; } public write(writer: Writer): void { @@ -27,5 +32,17 @@ export class GoTypeReference extends AstNode { } const alias = writer.addImport(this.importPath); writer.write(`${alias}.${this.name}`); + if (this.generics != null) { + writer.write("["); + this.generics.forEach((generic, idx) => { + if (idx > 0) { + writer.write(", "); + } + if (generic != null) { + generic.write(writer); + } + }); + writer.write("]"); + } } } diff --git a/generators/go-v2/ast/src/ast/Method.ts b/generators/go-v2/ast/src/ast/Method.ts index 95b3ddaa1cf..9c085f9894d 100644 --- a/generators/go-v2/ast/src/ast/Method.ts +++ b/generators/go-v2/ast/src/ast/Method.ts @@ -8,12 +8,12 @@ import { Writer } from "./core/Writer"; export declare namespace Method { interface Args { - /* The name of the method */ - name: string; /* The parameters of the method */ parameters: Parameter[]; /* The return type of the method */ return_: Type[]; + /* The name of the method */ + name?: string; /* The body of the method */ body?: CodeBlock; /* Documentation for the method */ @@ -24,9 +24,9 @@ export declare namespace Method { } export class Method extends AstNode { - public readonly name: string; public readonly parameters: Parameter[]; public readonly return_: Type[]; + public readonly name: string | undefined; public readonly body: CodeBlock | undefined; public readonly docs: string | undefined; public readonly typeReference: GoTypeReference | undefined; @@ -43,11 +43,13 @@ export class Method extends AstNode { public write(writer: Writer): void { writer.writeNode(new Comment({ docs: this.docs })); - writer.write("func "); + writer.write("func"); if (this.typeReference != null) { this.writeReceiver({ writer, typeReference: this.typeReference }); } - writer.write(`${this.name}`); + if (this.name != null) { + writer.write(` ${this.name}`); + } if (this.parameters.length === 0) { writer.write("() "); } else { diff --git a/generators/go-v2/ast/src/ast/MethodInvocation.ts b/generators/go-v2/ast/src/ast/MethodInvocation.ts index ef11c896cde..f20a383b27d 100644 --- a/generators/go-v2/ast/src/ast/MethodInvocation.ts +++ b/generators/go-v2/ast/src/ast/MethodInvocation.ts @@ -10,6 +10,8 @@ export declare namespace MethodInvocation { method: string; /* The arguments passed to the method */ arguments_: AstNode[]; + /* Whether to write the invocation on multiple lines */ + multiline?: boolean; } } @@ -17,19 +19,21 @@ export class MethodInvocation extends AstNode { private on: AstNode; private method: string; private arguments_: AstNode[]; + private multiline: boolean | undefined; - constructor({ method, arguments_, on }: MethodInvocation.Args) { + constructor({ method, arguments_, on, multiline }: MethodInvocation.Args) { super(); this.on = on; this.method = method; this.arguments_ = arguments_; + this.multiline = multiline; } public write(writer: Writer): void { this.on.write(writer); writer.write("."); writer.write(this.method); - writeArguments({ writer, arguments_: this.arguments_ }); + writeArguments({ writer, arguments_: this.arguments_, multiline: this.multiline }); } } diff --git a/generators/go-v2/ast/src/ast/Pointer.ts b/generators/go-v2/ast/src/ast/Pointer.ts new file mode 100644 index 00000000000..e75ee517db3 --- /dev/null +++ b/generators/go-v2/ast/src/ast/Pointer.ts @@ -0,0 +1,23 @@ +import { AstNode } from "./core/AstNode"; +import { Writer } from "./core/Writer"; + +export declare namespace Pointer { + interface Args { + /* The value of the pointer */ + node: AstNode; + } +} + +export class Pointer extends AstNode { + public readonly node: AstNode; + + constructor({ node }: Pointer.Args) { + super(); + this.node = node; + } + + public write(writer: Writer): void { + writer.write("*"); + this.node.write(writer); + } +} diff --git a/generators/go-v2/ast/src/ast/Selector.ts b/generators/go-v2/ast/src/ast/Selector.ts new file mode 100644 index 00000000000..00cb6e2ceaa --- /dev/null +++ b/generators/go-v2/ast/src/ast/Selector.ts @@ -0,0 +1,28 @@ +import { AstNode } from "./core/AstNode"; +import { Writer } from "./core/Writer"; + +export declare namespace Selector { + interface Args { + /* The node to select from */ + on: AstNode; + /* The node to select (e.g. a field name) */ + selector: AstNode; + } +} + +export class Selector extends AstNode { + public readonly on: AstNode; + public readonly selector: AstNode; + + constructor({ on, selector }: Selector.Args) { + super(); + this.on = on; + this.selector = selector; + } + + public write(writer: Writer): void { + writer.writeNode(this.on); + writer.write("."); + writer.writeNode(this.selector); + } +} diff --git a/generators/go-v2/ast/src/ast/Struct.ts b/generators/go-v2/ast/src/ast/Struct.ts index 3575ac2c581..8e1a8af631a 100644 --- a/generators/go-v2/ast/src/ast/Struct.ts +++ b/generators/go-v2/ast/src/ast/Struct.ts @@ -1,6 +1,9 @@ +import { CodeBlock } from "@fern-api/browser-compatible-base-generator"; + import { Comment } from "./Comment"; import { Field } from "./Field"; import { Method } from "./Method"; +import { Parameter } from "./Parameter"; import { AstNode } from "./core/AstNode"; import { Writer } from "./core/Writer"; @@ -13,6 +16,11 @@ export declare namespace Struct { /* Docs associated with the class */ docs?: string; } + + interface Constructor { + parameters: Parameter[]; + body: AstNode; + } } export class Struct extends AstNode { @@ -20,6 +28,7 @@ export class Struct extends AstNode { public readonly importPath: string; public readonly docs: string | undefined; + public constructor_: Struct.Constructor | undefined; public readonly fields: Field[] = []; public readonly methods: Method[] = []; @@ -30,12 +39,16 @@ export class Struct extends AstNode { this.docs = docs; } - public addField(field: Field): void { - this.fields.push(field); + public addConstructor(constructor: Struct.Constructor): void { + this.constructor_ = constructor; } - public addMethod(method: Method): void { - this.methods.push(method); + public addField(...fields: Field[]): void { + this.fields.push(...fields); + } + + public addMethod(...methods: Method[]): void { + this.methods.push(...methods); } public write(writer: Writer): void { @@ -54,6 +67,11 @@ export class Struct extends AstNode { writer.writeLine("}"); } + if (this.constructor_ != null) { + writer.newLine(); + this.writeConstructor({ writer, constructor: this.constructor_ }); + } + if (this.methods.length > 0) { writer.newLine(); for (const method of this.methods) { @@ -62,4 +80,21 @@ export class Struct extends AstNode { } } } + + private writeConstructor({ writer, constructor }: { writer: Writer; constructor: Struct.Constructor }): void { + writer.write(`func New${this.name}(`); + constructor.parameters.forEach((parameter, index) => { + if (index > 0) { + writer.write(", "); + } + writer.writeNode(parameter); + }); + writer.write(`) *${this.name} {`); + writer.newLine(); + writer.indent(); + writer.writeNode(constructor.body); + writer.writeNewLineIfLastLineNot(); + writer.dedent(); + writer.writeLine(`}`); + } } diff --git a/generators/go-v2/ast/src/ast/Type.ts b/generators/go-v2/ast/src/ast/Type.ts index b720a1532d3..49671bcb8b5 100644 --- a/generators/go-v2/ast/src/ast/Type.ts +++ b/generators/go-v2/ast/src/ast/Type.ts @@ -11,6 +11,7 @@ type InternalType = | Float64 | Date | DateTime + | Error_ | Int | Int64 | Map @@ -18,7 +19,8 @@ type InternalType = | Reference | Slice | String_ - | Uuid; + | Uuid + | Variadic; interface Any_ { type: "any"; @@ -40,6 +42,10 @@ interface Date { type: "date"; } +interface Error_ { + type: "error"; +} + interface DateTime { type: "dateTime"; } @@ -81,6 +87,11 @@ interface Uuid { type: "uuid"; } +interface Variadic { + type: "variadic"; + value: Type; +} + const NILABLE_TYPES = new Set(["any", "bytes", "map", "slice"]); export class Type extends AstNode { @@ -91,7 +102,7 @@ export class Type extends AstNode { public write(writer: Writer, { comment }: { comment?: boolean } = {}): void { switch (this.internalType.type) { case "any": - writer.write("interface{}"); + writer.write("any"); break; case "bool": writer.write("bool"); @@ -103,6 +114,9 @@ export class Type extends AstNode { case "dateTime": writer.writeNode(TimeTypeReference); break; + case "error": + writer.write("error"); + break; case "float64": writer.write("float64"); break; @@ -137,6 +151,10 @@ export class Type extends AstNode { case "uuid": writer.writeNode(UuidTypeReference); break; + case "variadic": + writer.write("..."); + this.internalType.value.write(writer); + break; default: assertNever(this.internalType); } @@ -177,6 +195,12 @@ export class Type extends AstNode { }); } + public static error(): Type { + return new this({ + type: "error" + }); + } + public static float64(): Type { return new this({ type: "float64" @@ -204,8 +228,8 @@ export class Type extends AstNode { } public static optional(value: Type): Type { - // Avoids double optional. if (this.isAlreadyOptional(value)) { + // Avoids double optional. return value; } return new this({ @@ -245,9 +269,24 @@ export class Type extends AstNode { }); } + public static variadic(value: Type): Type { + if (this.isAlreadyVariadic(value)) { + // Avoids double variadic. + return value; + } + return new this({ + type: "variadic", + value + }); + } + private static isAlreadyOptional(value: Type) { return value.internalType.type === "optional" || NILABLE_TYPES.has(value.internalType.type); } + + private static isAlreadyVariadic(value: Type) { + return value.internalType.type === "variadic"; + } } export const TimeTypeReference = new GoTypeReference({ @@ -259,3 +298,8 @@ export const UuidTypeReference = new GoTypeReference({ importPath: "github.com/google/uuid", name: "UUID" }); + +export const IoReaderTypeReference = new GoTypeReference({ + importPath: "io", + name: "Reader" +}); diff --git a/generators/go-v2/ast/src/ast/TypeInstantiation.ts b/generators/go-v2/ast/src/ast/TypeInstantiation.ts index 163105b958e..a66433a0bbd 100644 --- a/generators/go-v2/ast/src/ast/TypeInstantiation.ts +++ b/generators/go-v2/ast/src/ast/TypeInstantiation.ts @@ -113,6 +113,7 @@ interface Struct { type: "struct"; typeReference: GoTypeReference; fields: StructField[]; + generics?: Type[]; } export interface StructField { @@ -340,17 +341,20 @@ export class TypeInstantiation extends AstNode { public static structPointer({ typeReference, - fields + fields, + generics }: { typeReference: GoTypeReference; fields: StructField[]; + generics?: Type[]; }): TypeInstantiation { return new this({ type: "optional", value: new this({ type: "struct", typeReference, - fields + fields, + generics }) }); } @@ -403,7 +407,7 @@ export class TypeInstantiation extends AstNode { writer: Writer; value: any[]; // eslint-disable-line @typescript-eslint/no-explicit-any }): void { - writer.write("[]interface{}"); + writer.write("[]any"); if (value.length === 0) { writer.write("{}"); return; @@ -419,7 +423,7 @@ export class TypeInstantiation extends AstNode { } private writeAnyObject({ writer, value }: { writer: Writer; value: object }): void { - writer.write("map[string]interface{}"); + writer.write("map[string]any"); const entries = Object.entries(value); if (entries.length === 0) { writer.write("{}"); @@ -510,6 +514,17 @@ export class TypeInstantiation extends AstNode { private writeStruct({ writer, struct }: { writer: Writer; struct: Struct }): void { writer.writeNode(struct.typeReference); + if (struct.generics != null) { + writer.write("["); + struct.generics.forEach((generic, index) => { + if (index > 0) { + writer.write(", "); + } + writer.writeNode(generic); + }); + writer.write("]"); + } + const fields = filterNopStructFields({ fields: struct.fields }); if (fields.length === 0) { writer.write("{}"); diff --git a/generators/go-v2/ast/src/ast/index.ts b/generators/go-v2/ast/src/ast/index.ts index 0df86ad8434..92b31146a2b 100644 --- a/generators/go-v2/ast/src/ast/index.ts +++ b/generators/go-v2/ast/src/ast/index.ts @@ -9,6 +9,8 @@ export { GoTypeReference } from "./GoTypeReference"; export { Method } from "./Method"; export { MethodInvocation } from "./MethodInvocation"; export { Parameter } from "./Parameter"; +export { Pointer } from "./Pointer"; +export { Selector } from "./Selector"; export { Struct } from "./Struct"; export { Type } from "./Type"; export { TypeInstantiation, type StructField } from "./TypeInstantiation"; diff --git a/generators/go-v2/ast/src/ast/utils/writeArguments.ts b/generators/go-v2/ast/src/ast/utils/writeArguments.ts index a7e1211a64c..9f2c717b9cd 100644 --- a/generators/go-v2/ast/src/ast/utils/writeArguments.ts +++ b/generators/go-v2/ast/src/ast/utils/writeArguments.ts @@ -1,18 +1,40 @@ -import { Argument, Arguments, isNamedArgument } from "@fern-api/browser-compatible-base-generator"; +import { + Argument, + Arguments, + NamedArgument, + UnnamedArgument, + isNamedArgument +} from "@fern-api/browser-compatible-base-generator"; import { TypeInstantiation } from "../TypeInstantiation"; import { Writer } from "../core/Writer"; -export function writeArguments({ writer, arguments_ }: { writer: Writer; arguments_: Arguments }): void { +export function writeArguments({ + writer, + arguments_, + multiline = true +}: { + writer: Writer; + arguments_: Arguments; + multiline?: boolean; +}): void { const filteredArguments = filterNopArguments(arguments_); if (filteredArguments.length === 0) { writer.write("()"); return; } + if (multiline) { + writeMultiline({ writer, arguments_: filteredArguments }); + return; + } + writeCompact({ writer, arguments_: filteredArguments }); +} + +function writeMultiline({ writer, arguments_ }: { writer: Writer; arguments_: Arguments }): void { writer.writeLine("("); writer.indent(); - for (const argument of filteredArguments) { + for (const argument of arguments_) { writeArgument({ writer, argument }); writer.writeLine(","); } @@ -20,6 +42,17 @@ export function writeArguments({ writer, arguments_ }: { writer: Writer; argumen writer.write(")"); } +function writeCompact({ writer, arguments_ }: { writer: Writer; arguments_: Arguments }): void { + writer.write("("); + arguments_.forEach((argument, index) => { + if (index > 0) { + writer.write(", "); + } + writeArgument({ writer, argument }); + }); + writer.write(")"); +} + function writeArgument({ writer, argument }: { writer: Writer; argument: Argument }): void { if (isNamedArgument(argument)) { writer.writeNodeOrString(argument.assignment); @@ -28,8 +61,12 @@ function writeArgument({ writer, argument }: { writer: Writer; argument: Argumen } } -function filterNopArguments(arguments_: Argument[]): Argument[] { - return arguments_.filter( +function filterNopArguments(arguments_: Arguments): Arguments { + const filtered = arguments_.filter( (argument) => !(argument instanceof TypeInstantiation && TypeInstantiation.isNop(argument)) ); + if (arguments_.length > 0 && arguments_[0] != null && "name" in arguments_[0]) { + return filtered as NamedArgument[]; + } + return filtered as UnnamedArgument[]; } diff --git a/generators/go-v2/ast/src/context/AbstractGoGeneratorContext.ts b/generators/go-v2/ast/src/context/AbstractGoGeneratorContext.ts index c33204c3d2a..126c797a46b 100644 --- a/generators/go-v2/ast/src/context/AbstractGoGeneratorContext.ts +++ b/generators/go-v2/ast/src/context/AbstractGoGeneratorContext.ts @@ -7,6 +7,8 @@ import { assertNever } from "@fern-api/core-utils"; import { RelativeFilePath } from "@fern-api/path-utils"; import { + ErrorDeclaration, + ErrorId, FernFilepath, HttpService, IntermediateRepresentation, @@ -22,10 +24,12 @@ import { } from "@fern-fern/ir-sdk/api"; import { go } from ".."; -import { TimeTypeReference, UuidTypeReference } from "../ast/Type"; +import { IoReaderTypeReference, TimeTypeReference, UuidTypeReference } from "../ast/Type"; import { BaseGoCustomConfigSchema } from "../custom-config/BaseGoCustomConfigSchema"; import { resolveRootImportPath } from "../custom-config/resolveRootImportPath"; import { GoTypeMapper } from "./GoTypeMapper"; +import { GoValueFormatter } from "./GoValueFormatter"; +import { GoZeroValueMapper } from "./GoZeroValueMapper"; export interface FileLocation { importPath: string; @@ -37,6 +41,8 @@ export abstract class AbstractGoGeneratorContext< > extends AbstractGeneratorContext { private rootImportPath: string; public readonly goTypeMapper: GoTypeMapper; + public readonly goValueFormatter: GoValueFormatter; + public readonly goZeroValueMapper: GoZeroValueMapper; public constructor( public readonly ir: IntermediateRepresentation, @@ -46,6 +52,8 @@ export abstract class AbstractGoGeneratorContext< ) { super(config, generatorNotificationService); this.goTypeMapper = new GoTypeMapper(this); + this.goValueFormatter = new GoValueFormatter(this); + this.goZeroValueMapper = new GoZeroValueMapper(this); this.rootImportPath = resolveRootImportPath({ config: this.config, customConfig: this.customConfig @@ -68,6 +76,14 @@ export abstract class AbstractGoGeneratorContext< return subpackage; } + public getErrorDeclarationOrThrow(errorId: ErrorId): ErrorDeclaration { + const errorDeclaration = this.ir.errors[errorId]; + if (errorDeclaration == null) { + throw new Error(`Error declaration with id ${errorId} not found`); + } + return errorDeclaration; + } + public getClassName(name: Name): string { return name.pascalCase.unsafeName; } @@ -80,14 +96,146 @@ export abstract class AbstractGoGeneratorContext< return `${this.rootImportPath}/core`; } + public getInternalImportPath(): string { + return `${this.rootImportPath}/internal`; + } + + public getOptionImportPath(): string { + return `${this.rootImportPath}/option`; + } + public getFieldName(name: Name): string { return name.pascalCase.unsafeName; } + public getParameterName(name: Name): string { + return name.camelCase.safeName; + } + + public maybeUnwrapIterable(typeReference: TypeReference): TypeReference | undefined { + switch (typeReference.type) { + case "container": + const container = typeReference.container; + switch (container.type) { + case "list": + return container.list; + case "set": + return container.set; + case "optional": + return this.maybeUnwrapIterable(container.optional); + case "nullable": + return this.maybeUnwrapIterable(container.nullable); + case "literal": + case "map": + return undefined; + default: + assertNever(container); + } + case "named": + const typeDeclaration = this.getTypeDeclarationOrThrow(typeReference.typeId).shape; + switch (typeDeclaration.type) { + case "alias": + return this.maybeUnwrapIterable(typeDeclaration.aliasOf); + case "enum": + case "object": + case "union": + case "undiscriminatedUnion": + return undefined; + default: + assertNever(typeDeclaration); + } + case "primitive": + case "unknown": + return undefined; + default: + assertNever(typeReference); + } + } + + public maybeUnwrapOptionalOrNullable(typeReference: TypeReference): TypeReference | undefined { + switch (typeReference.type) { + case "container": + const container = typeReference.container; + switch (container.type) { + case "optional": + return container.optional; + case "nullable": + return container.nullable; + case "list": + case "set": + case "literal": + case "map": + return undefined; + default: + assertNever(container); + } + case "named": + case "primitive": + case "unknown": + return undefined; + default: + assertNever(typeReference); + } + } + + /** + * Returns true if the type reference needs to be dereferenced to get the + * underlying type. + * + * Container types like lists, maps, and sets are already nil-able, so they + * don't require a dereference prefix. + */ + public needsOptionalDereference(typeReference: TypeReference): boolean { + switch (typeReference.type) { + case "named": + const typeDeclaration = this.getTypeDeclarationOrThrow(typeReference.typeId).shape; + switch (typeDeclaration.type) { + case "alias": + return this.needsOptionalDereference(typeDeclaration.aliasOf); + case "enum": + return true; + case "object": + case "union": + case "undiscriminatedUnion": + return false; + default: + assertNever(typeDeclaration); + } + case "primitive": + return true; + case "container": + case "unknown": + return false; + default: + assertNever(typeReference); + } + } + public getLiteralAsString(literal: Literal): string { return literal.type === "string" ? `'${literal.string}'` : literal.boolean ? "'true'" : "'false'"; } + public getContextTypeReference(): go.TypeReference { + return go.typeReference({ + name: "Context", + importPath: "context" + }); + } + + public getZeroTime(): go.TypeInstantiation { + return go.TypeInstantiation.struct({ + typeReference: TimeTypeReference, + fields: [] + }); + } + + public getZeroUuid(): go.TypeInstantiation { + return go.TypeInstantiation.struct({ + typeReference: UuidTypeReference, + fields: [] + }); + } + public getUuidTypeReference(): go.TypeReference { return UuidTypeReference; } @@ -96,6 +244,10 @@ export abstract class AbstractGoGeneratorContext< return TimeTypeReference; } + public getIoReaderTypeReference(): go.TypeReference { + return IoReaderTypeReference; + } + public isOptional(typeReference: TypeReference): boolean { switch (typeReference.type) { case "container": @@ -188,32 +340,40 @@ export abstract class AbstractGoGeneratorContext< primitive }: { typeReference: TypeReference; - primitive?: PrimitiveTypeV1; + primitive: PrimitiveTypeV1; }): boolean { + return this.maybePrimitive(typeReference) === primitive; + } + + public maybePrimitive(typeReference: TypeReference): PrimitiveTypeV1 | undefined { switch (typeReference.type) { case "container": - switch (typeReference.container.type) { + const container = typeReference.container; + switch (container.type) { case "optional": - return this.isPrimitive({ typeReference: typeReference.container.optional, primitive }); + return this.maybePrimitive(container.optional); case "nullable": - return this.isPrimitive({ typeReference: typeReference.container.nullable, primitive }); + return this.maybePrimitive(container.nullable); + case "list": + case "set": + case "literal": + case "map": + return undefined; + default: + assertNever(container); } - return false; case "named": { const declaration = this.getTypeDeclarationOrThrow(typeReference.typeId); if (declaration.shape.type === "alias") { - return this.isPrimitive({ typeReference: declaration.shape.aliasOf, primitive }); + return this.maybePrimitive(declaration.shape.aliasOf); } - return false; + return undefined; } case "primitive": { - if (primitive == null) { - return true; - } - return typeReference.primitive.v1 === primitive; + return typeReference.primitive.v1; } case "unknown": { - return false; + return undefined; } default: assertNever(typeReference); @@ -245,7 +405,15 @@ export abstract class AbstractGoGeneratorContext< } protected getFileLocation(filepath: FernFilepath, suffix?: string): FileLocation { - let parts = filepath.packagePath.map((path) => path.pascalCase.safeName.toLowerCase()); + return this.getLocation(filepath.allParts, suffix); + } + + protected getPackageLocation(filepath: FernFilepath, suffix?: string): FileLocation { + return this.getLocation(filepath.packagePath, suffix); + } + + private getLocation(names: Name[], suffix?: string): FileLocation { + let parts = names.map((name) => name.pascalCase.safeName.toLowerCase()); parts = suffix != null ? [...parts, suffix] : parts; return { importPath: [this.getRootImportPath(), ...parts].join("/"), diff --git a/generators/go-v2/ast/src/context/GoTypeMapper.ts b/generators/go-v2/ast/src/context/GoTypeMapper.ts index fb84849963e..7a0b952cce1 100644 --- a/generators/go-v2/ast/src/context/GoTypeMapper.ts +++ b/generators/go-v2/ast/src/context/GoTypeMapper.ts @@ -79,8 +79,8 @@ export class GoTypeMapper { return PrimitiveTypeV1._visit(primitive.v1, { integer: () => go.Type.int(), long: () => go.Type.int64(), - uint: () => go.Type.int(), // TODO: Add support for uint types in the Go generator. - uint64: () => go.Type.int64(), // TODO: Add support for uint64 types in the Go generator. + uint: () => go.Type.int(), + uint64: () => go.Type.int64(), float: () => go.Type.float64(), double: () => go.Type.float64(), boolean: () => go.Type.bool(), @@ -106,6 +106,6 @@ export class GoTypeMapper { } private convertNamed({ named }: { named: DeclaredTypeName }): Type { - return go.Type.reference(this.convertToTypeReference(named)); + return go.Type.pointer(go.Type.reference(this.convertToTypeReference(named))); } } diff --git a/generators/go-v2/ast/src/context/GoValueFormatter.ts b/generators/go-v2/ast/src/context/GoValueFormatter.ts new file mode 100644 index 00000000000..d4adb9e4e14 --- /dev/null +++ b/generators/go-v2/ast/src/context/GoValueFormatter.ts @@ -0,0 +1,105 @@ +import { assertNever } from "@fern-api/core-utils"; + +import { PrimitiveTypeV1, TypeReference } from "@fern-fern/ir-sdk/api"; + +import { go } from "../"; +import { BaseGoCustomConfigSchema } from "../custom-config/BaseGoCustomConfigSchema"; +import { AbstractGoGeneratorContext } from "./AbstractGoGeneratorContext"; + +export declare namespace GoValueFormatter { + interface Args { + reference: TypeReference; + value: go.AstNode; + } + + interface Result { + formatted: go.AstNode; + zeroValue: go.AstNode; + isIterable: boolean; + isOptional: boolean; + isPrimitive: boolean; + } +} + +export class GoValueFormatter { + private context: AbstractGoGeneratorContext; + + constructor(context: AbstractGoGeneratorContext) { + this.context = context; + } + + public convert({ reference, value }: GoValueFormatter.Args): GoValueFormatter.Result { + const iterableType = this.context.maybeUnwrapIterable(reference); + if (iterableType != null) { + const format = this.convert({ reference: iterableType, value }); + return { + ...format, + isIterable: true + }; + } + + let prefix = ""; + let suffix = ""; + let isOptional = false; + let isPrimitive = false; + + const optionalOrNullableType = this.context.maybeUnwrapOptionalOrNullable(reference); + if (optionalOrNullableType != null) { + if (this.context.needsOptionalDereference(optionalOrNullableType)) { + prefix = "*"; + } + isOptional = true; + } + + const primitive = this.context.maybePrimitive(reference); + if (primitive != null) { + if (isOptional) { + prefix = "*"; + } + switch (primitive) { + case PrimitiveTypeV1.DateTime: + prefix = ""; + suffix = ".Format(time.RFC3339)"; + break; + case PrimitiveTypeV1.Date: + prefix = ""; + suffix = `.Format("2006-01-02")`; + break; + case PrimitiveTypeV1.Base64: + prefix = "base64.StdEncoding.EncodeToString(" + prefix + ")"; + suffix = ")"; + break; + case PrimitiveTypeV1.Uuid: + case PrimitiveTypeV1.BigInteger: + case PrimitiveTypeV1.Integer: + case PrimitiveTypeV1.Long: + case PrimitiveTypeV1.Uint: + case PrimitiveTypeV1.Uint64: + case PrimitiveTypeV1.Float: + case PrimitiveTypeV1.Double: + case PrimitiveTypeV1.Boolean: + case PrimitiveTypeV1.String: + break; + default: + assertNever(primitive); + } + isPrimitive = true; + } + + return { + formatted: this.format({ prefix, suffix, value }), + zeroValue: this.context.goZeroValueMapper.convert({ reference }), + isIterable: false, + isOptional, + isPrimitive + }; + } + + private format({ prefix, suffix, value }: { prefix: string; suffix: string; value: go.AstNode }): go.AstNode { + return go.codeblock((writer) => { + writer.write(prefix); + writer.writeNode(value); + writer.write(suffix); + }); + } +} diff --git a/generators/go-v2/ast/src/context/GoZeroValueMapper.ts b/generators/go-v2/ast/src/context/GoZeroValueMapper.ts new file mode 100644 index 00000000000..ead93db2150 --- /dev/null +++ b/generators/go-v2/ast/src/context/GoZeroValueMapper.ts @@ -0,0 +1,88 @@ +import { assertNever } from "@fern-api/core-utils"; + +import { ContainerType, Literal, PrimitiveType, PrimitiveTypeV1, TypeReference } from "@fern-fern/ir-sdk/api"; + +import { go } from "../"; +import { TypeInstantiation } from "../ast"; +import { BaseGoCustomConfigSchema } from "../custom-config/BaseGoCustomConfigSchema"; +import { AbstractGoGeneratorContext } from "./AbstractGoGeneratorContext"; + +export declare namespace GoZeroValueMapper { + interface Args { + reference: TypeReference; + } +} + +export class GoZeroValueMapper { + private context: AbstractGoGeneratorContext; + + constructor(context: AbstractGoGeneratorContext) { + this.context = context; + } + + public convert({ reference }: GoZeroValueMapper.Args): TypeInstantiation { + switch (reference.type) { + case "container": + return this.convertContainer({ + container: reference.container + }); + case "named": + return this.convertNamed({ named: reference }); + case "primitive": + return this.convertPrimitive(reference); + case "unknown": + return go.TypeInstantiation.nil(); + default: + assertNever(reference); + } + } + + private convertContainer({ container }: { container: ContainerType }): TypeInstantiation { + switch (container.type) { + case "list": + case "map": + case "set": + case "optional": + case "nullable": + return go.TypeInstantiation.nil(); + case "literal": + return this.convertLiteral({ literal: container.literal }); + default: + assertNever(container); + } + } + + private convertPrimitive({ primitive }: { primitive: PrimitiveType }): TypeInstantiation { + return PrimitiveTypeV1._visit(primitive.v1, { + integer: () => go.TypeInstantiation.int(0), + long: () => go.TypeInstantiation.int64(0), + uint: () => go.TypeInstantiation.int(0), + uint64: () => go.TypeInstantiation.int64(0), + float: () => go.TypeInstantiation.float64(0), + double: () => go.TypeInstantiation.float64(0), + boolean: () => go.TypeInstantiation.bool(false), + string: () => go.TypeInstantiation.string(""), + date: () => this.context.getZeroTime(), + dateTime: () => this.context.getZeroTime(), + uuid: () => this.context.getZeroUuid(), + base64: () => go.TypeInstantiation.nil(), + bigInteger: () => go.TypeInstantiation.int(0), + _other: () => go.TypeInstantiation.nil() + }); + } + + private convertLiteral({ literal }: { literal: Literal }): TypeInstantiation { + switch (literal.type) { + case "boolean": + return go.TypeInstantiation.bool(false); + case "string": + return go.TypeInstantiation.string(""); + default: + assertNever(literal); + } + } + + private convertNamed({}: {}): TypeInstantiation { + return go.TypeInstantiation.nil(); + } +} diff --git a/generators/go-v2/ast/src/go.ts b/generators/go-v2/ast/src/go.ts index 3db0dc6657c..ca8cf253e14 100644 --- a/generators/go-v2/ast/src/go.ts +++ b/generators/go-v2/ast/src/go.ts @@ -9,6 +9,8 @@ import { Method, MethodInvocation, Parameter, + Pointer, + Selector, Struct } from "./ast"; @@ -48,6 +50,14 @@ export function parameter(args: Parameter.Args): Parameter { return new Parameter(args); } +export function pointer(args: Pointer.Args): Pointer { + return new Pointer(args); +} + +export function selector(args: Selector.Args): Selector { + return new Selector(args); +} + export function struct(args: Struct.Args): Struct { return new Struct(args); } @@ -68,6 +78,8 @@ export { Method, MethodInvocation, Parameter, + Pointer, + Selector, Struct, Type, TypeInstantiation, diff --git a/generators/go-v2/base/src/project/GoProject.ts b/generators/go-v2/base/src/project/GoProject.ts index eb5f4f099cf..c4024fc1f29 100644 --- a/generators/go-v2/base/src/project/GoProject.ts +++ b/generators/go-v2/base/src/project/GoProject.ts @@ -3,6 +3,7 @@ import { mkdir } from "fs/promises"; import { AbstractProject, File } from "@fern-api/base-generator"; import { AbsoluteFilePath } from "@fern-api/fs-utils"; import { AbstractGoGeneratorContext, BaseGoCustomConfigSchema } from "@fern-api/go-ast"; +import { loggingExeca } from "@fern-api/logging-execa"; /** * In memory representation of a Go project. @@ -28,6 +29,10 @@ export class GoProject extends AbstractProject { + await file.write(this.absolutePathToOutputDirectory); + } + private async writeGoFiles({ absolutePathToDirectory, files @@ -38,15 +43,10 @@ export class GoProject extends AbstractProject await file.write(absolutePathToDirectory))); if (files.length > 0) { - // TODO: Uncomment this once the go-v2 generator is responsible for producing the go.mod file. - // Otherwise, we get a "directory prefix . does not contain main module or its selected dependencies" error. - // - // --- - // - // await loggingExeca(this.context.logger, "go", ["fmt", "./..."], { - // doNotPipeOutput: true, - // cwd: absolutePathToDirectory - // }); + await loggingExeca(this.context.logger, "go", ["fmt", "./..."], { + doNotPipeOutput: true, + cwd: absolutePathToDirectory + }); } return absolutePathToDirectory; } diff --git a/generators/go-v2/sdk/src/SdkGeneratorCli.ts b/generators/go-v2/sdk/src/SdkGeneratorCli.ts index 1ef0622291c..6c780beb9f8 100644 --- a/generators/go-v2/sdk/src/SdkGeneratorCli.ts +++ b/generators/go-v2/sdk/src/SdkGeneratorCli.ts @@ -1,9 +1,5 @@ -import urlJoin from "url-join"; - import { File, GeneratorNotificationService } from "@fern-api/base-generator"; -import { FernIr } from "@fern-api/dynamic-ir-sdk"; import { RelativeFilePath } from "@fern-api/fs-utils"; -import { go } from "@fern-api/go-ast"; import { AbstractGoGeneratorCli } from "@fern-api/go-base"; import { DynamicSnippetsGenerator } from "@fern-api/go-dynamic-snippets"; @@ -13,6 +9,8 @@ import { IntermediateRepresentation } from "@fern-fern/ir-sdk/api"; import { SdkCustomConfigSchema } from "./SdkCustomConfig"; import { SdkGeneratorContext } from "./SdkGeneratorContext"; +import { ModuleConfigWriter } from "./module/ModuleConfigWriter"; +import { RawClientGenerator } from "./raw-client/RawClientGenerator"; import { convertDynamicEndpointSnippetRequest } from "./utils/convertEndpointSnippetRequest"; import { convertIr } from "./utils/convertIr"; import { WireTestGenerator } from "./wiretest/WireTestGenerator"; @@ -53,6 +51,9 @@ export class SdkGeneratorCLI extends AbstractGoGeneratorCli { + this.writeGoMod(context); + this.generateRawClients(context); + if (this.shouldGenerateReadme(context)) { try { const endpointSnippets = this.generateSnippets({ context }); @@ -64,35 +65,57 @@ export class SdkGeneratorCLI extends AbstractGoGeneratorCli { - public readonly generatorAgent: GoGeneratorAgent; public readonly project: GoProject; + public readonly caller: Caller; + public readonly endpointGenerator: EndpointGenerator; + public readonly generatorAgent: GoGeneratorAgent; public constructor( public readonly ir: IntermediateRepresentation, @@ -21,6 +41,8 @@ export class SdkGeneratorContext extends AbstractGoGeneratorContext environment.id === id)?.url; + case "multipleBaseUrls": { + for (const environment of environments.environments) { + const url = environment.urls[id]; + if (url != null) { + return url; + } + } + return undefined; + } + default: + assertNever(environments); + } + } + + public getModuleConfig({ outputMode }: { outputMode: OutputMode }): ModuleConfig | undefined { + const githubConfig = this.getGithubOutputMode({ outputMode }); + if (githubConfig == null && this.customConfig.module == null) { + return undefined; + } + if (githubConfig == null) { + return this.customConfig.module; + } + if (this.customConfig.module == null) { + // A GitHub configuration was provided, so the module config should use + // the GitHub configuration's repository url. + const modulePath = githubConfig.repoUrl.replace("https://", ""); + return { + ...ModuleConfig.DEFAULT, + path: modulePath + }; + } + return { + path: this.customConfig.module.path, + version: this.customConfig.module.version, + imports: this.customConfig.module.imports ?? ModuleConfig.DEFAULT.imports + }; + } + + private getGithubOutputMode({ outputMode }: { outputMode: OutputMode }): GithubOutputMode | undefined { + switch (outputMode.type) { + case "github": + return outputMode; + case "publish": + case "downloadFiles": + return undefined; + default: + assertNever(outputMode); + } + } + + public getRootClientDirectory(): RelativeFilePath { + return RelativeFilePath.of(this.getClientPackageName()); + } + + public getRootClientImportPath(): string { + return `${this.getRootImportPath()}/${this.getClientPackageName()}`; + } + + public getRootClientClassReference(): go.TypeReference { + return go.typeReference({ + name: this.getClientClassName(), + importPath: this.getRootClientImportPath() + }); + } + + public getRootRawClientClassReference(): go.TypeReference { + return go.typeReference({ + name: this.getRawClientClassName(), + importPath: this.getRootClientImportPath() + }); + } + + public getSubpackageClientClassReference(subpackage: Subpackage): go.TypeReference { + return go.typeReference({ + name: this.getClientClassName(), + importPath: this.getSubpackageClientFileLocation(subpackage).importPath + }); + } + + public getSubpackageRawClientClassReference(subpackage: Subpackage): go.TypeReference { + return go.typeReference({ + name: this.getRawClientClassName(), + importPath: this.getSubpackageClientFileLocation(subpackage).importPath + }); + } + + public getSubpackageClientPackageName(subpackage: Subpackage): string { + return this.getFileLocation(subpackage.fernFilepath).importPath.split("/").pop() ?? ""; + } + + public getSubpackageClientFileLocation(subpackage: Subpackage): FileLocation { + // TODO: Add support for conditionally including the nested 'client' package element. + return this.getFileLocation(subpackage.fernFilepath); + } + + public getSubpackageClientField(subpackage: Subpackage): go.Field { + return go.field({ + name: this.getClientClassName(), + type: go.Type.reference(this.getSubpackageClientClassReference(subpackage)) + }); + } + + public shouldGenerateSubpackageClient(subpackage: Subpackage): boolean { + if (subpackage.service != null) { + return true; + } + for (const subpackageId of subpackage.subpackages) { + const subpackage = this.getSubpackageOrThrow(subpackageId); + if (this.shouldGenerateSubpackageClient(subpackage)) { + return true; + } + } + return false; + } + + public getContextParameter(): go.Parameter { + return go.parameter({ + name: "ctx", + type: go.Type.reference(this.getContextTypeReference()) + }); + } + + public getContextParameterReference(): go.AstNode { + return go.codeblock("ctx"); + } + + public getErrorCodesTypeReference(): go.TypeReference { + return go.typeReference({ + name: "ErrorCodes", + importPath: this.getInternalImportPath() + }); + } + + public getCoreApiErrorTypeReference(): go.TypeReference { + return go.typeReference({ + name: "APIError", + importPath: this.getCoreImportPath() + }); + } + + public getVariadicRequestOptionParameter(): go.Parameter { + return go.parameter({ + name: "opts", + type: this.getVariadicRequestOptionType() + }); + } + + public getVariadicIdempotentRequestOptionParameter(): go.Parameter { + return go.parameter({ + name: "opts", + type: this.getVariadicIdempotentRequestOptionType() + }); + } + + public getVariadicRequestOptionType(): go.Type { + return go.Type.variadic(go.Type.reference(this.getRequestOptionTypeReference())); + } + + public getVariadicIdempotentRequestOptionType(): go.Type { + return go.Type.variadic(go.Type.reference(this.getIdempotentRequestOptionTypeReference())); + } + + public getRequestOptionTypeReference(): go.TypeReference { + return go.typeReference({ + name: "RequestOption", + importPath: this.getOptionImportPath() + }); + } + + public getIdempotentRequestOptionTypeReference(): go.TypeReference { + return go.typeReference({ + name: "IdempotentRequestOption", + importPath: this.getOptionImportPath() + }); + } + + public callBytesNewBuffer(): go.FuncInvocation { + return go.invokeFunc({ + func: go.typeReference({ name: "NewBuffer", importPath: "bytes" }), + arguments_: [go.codeblock("nil")], + multiline: false + }); + } + + public callNewRequestOptions(argument: go.AstNode): go.FuncInvocation { + return go.invokeFunc({ + func: go.typeReference({ + name: "NewRequestOptions", + importPath: this.getCoreImportPath() + }), + arguments_: [argument], + multiline: false + }); + } + + public callNewIdempotentRequestOptions(argument: go.AstNode): go.FuncInvocation { + return go.invokeFunc({ + func: go.typeReference({ + name: "NewIdempotentRequestOptions", + importPath: this.getCoreImportPath() + }), + arguments_: [argument], + multiline: false + }); + } + + public callEncodeUrl(arguments_: go.AstNode[]): go.FuncInvocation { + return this.callInternalFunc("EncodeURL", arguments_); + } + + public callResolveBaseURL(arguments_: go.AstNode[]): go.FuncInvocation { + return this.callInternalFunc("ResolveBaseURL", arguments_); + } + + public callQueryValues(arguments_: go.AstNode[]): go.FuncInvocation { + return this.callInternalFunc("QueryValues", arguments_); + } + + public callMergeHeaders(arguments_: go.AstNode[]): go.FuncInvocation { + return this.callInternalFunc("MergeHeaders", arguments_); + } + + public getRawResponseTypeReference(valueType: go.Type): go.TypeReference { + return go.typeReference({ + name: "Response", + importPath: this.getCoreImportPath(), + generics: [valueType] + }); + } + + public getStreamTypeReference(valueType: go.Type): go.TypeReference { + return go.typeReference({ + name: "Stream", + importPath: this.getCoreImportPath(), + generics: [valueType] + }); + } + + public getStreamPayload(streamingResponse: StreamingResponse): go.Type { + switch (streamingResponse.type) { + case "json": + case "sse": + return this.goTypeMapper.convert({ reference: streamingResponse.payload }); + case "text": + return go.Type.string(); + default: + assertNever(streamingResponse); + } + } + + public getRequestWrapperTypeReference(serviceId: ServiceId, requestName: Name): go.TypeReference { + return go.typeReference({ + name: this.getClassName(requestName), + importPath: this.getLocationForWrappedRequest(serviceId).importPath + }); + } + + public getEndpointRequestType({ + endpoint, + serviceId + }: { + endpoint: HttpEndpoint; + serviceId: ServiceId; + }): go.Type | undefined { + const sdkRequest = endpoint.sdkRequest; + if (sdkRequest == null) { + return undefined; + } + switch (sdkRequest.shape.type) { + case "justRequestBody": + return this.getEndpointRequestBodyType(sdkRequest.shape.value); + case "wrapper": { + const location = this.getLocationForWrappedRequest(serviceId); + return go.Type.pointer( + go.Type.reference( + go.typeReference({ + name: this.getClassName(sdkRequest.shape.wrapperName), + importPath: location.importPath + }) + ) + ); + } + default: + assertNever(sdkRequest.shape); + } + } + + public shouldSkipWrappedRequest({ + endpoint, + wrapper + }: { + endpoint: HttpEndpoint; + wrapper: SdkRequestWrapper; + }): boolean { + return ( + (wrapper.onlyPathParameters ?? false) && !this.includePathParametersInWrappedRequest({ endpoint, wrapper }) + ); + } + + public includePathParametersInWrappedRequest({ + endpoint, + wrapper + }: { + endpoint: HttpEndpoint; + wrapper: SdkRequestWrapper; + }): boolean { + const inlinePathParameters = this.customConfig.inlinePathParameters; + if (inlinePathParameters == null) { + return false; + } + const wrapperShouldIncludePathParameters = wrapper.includePathParameters ?? false; + return endpoint.allPathParameters.length > 0 && inlinePathParameters && wrapperShouldIncludePathParameters; + } + + public accessRequestProperty({ + requestParameterName, + propertyName + }: { + requestParameterName: Name; + propertyName: Name; + }): string { + const requestParameter = this.getParameterName(requestParameterName); + return `${requestParameter}.${this.getFieldName(propertyName)}`; + } + + public getNetHttpHeaderTypeReference(): go.TypeReference { + return go.typeReference({ + name: "Header", + importPath: "net/http" + }); + } + + public getNetHttpMethodTypeReference(method: HttpMethod): go.TypeReference { + return go.typeReference({ + name: this.getNetHttpMethodTypeReferenceName(method), + importPath: "net/http" + }); + } + + private getEndpointRequestBodyType(requestBodyType: SdkRequestBodyType): go.Type { + switch (requestBodyType.type) { + case "typeReference": + return this.goTypeMapper.convert({ reference: requestBodyType.requestBodyType }); + case "bytes": { + return go.Type.reference(this.getIoReaderTypeReference()); + } + default: + assertNever(requestBodyType); + } + } + + private callInternalFunc(name: string, arguments_: go.AstNode[]): go.FuncInvocation { + return go.invokeFunc({ + func: go.typeReference({ + name, + importPath: this.getInternalImportPath() + }), + arguments_ + }); + } + + private getNetHttpMethodTypeReferenceName(method: HttpMethod): string { + switch (method) { + case "GET": + return "MethodGet"; + case "POST": + return "MethodPost"; + case "PUT": + return "MethodPut"; + case "PATCH": + return "MethodPatch"; + case "DELETE": + return "MethodDelete"; + case "HEAD": + return "MethodHead"; + default: + assertNever(method); + } + } + + private getLocationForWrappedRequest(serviceId: ServiceId): FileLocation { + const httpService = this.getHttpServiceOrThrow(serviceId); + return this.getPackageLocation(httpService.name.fernFilepath); + } } diff --git a/generators/go-v2/sdk/src/endpoint/AbstractEndpointGenerator.ts b/generators/go-v2/sdk/src/endpoint/AbstractEndpointGenerator.ts new file mode 100644 index 00000000000..c15a33bcec0 --- /dev/null +++ b/generators/go-v2/sdk/src/endpoint/AbstractEndpointGenerator.ts @@ -0,0 +1,136 @@ +import { assertNever } from "@fern-api/core-utils"; +import { go } from "@fern-api/go-ast"; + +import { HttpEndpoint, HttpService, PathParameter, SdkRequest, ServiceId } from "@fern-fern/ir-sdk/api"; + +import { SdkGeneratorContext } from "../SdkGeneratorContext"; +import { EndpointSignatureInfo } from "./EndpointSignatureInfo"; +import { EndpointRequest } from "./request/EndpointRequest"; +import { getEndpointRequest } from "./utils/getEndpointRequest"; +import { getEndpointReturnTypes } from "./utils/getEndpointReturnTypes"; +import { getEndpointReturnZeroValues } from "./utils/getEndpointReturnZeroValue"; +import { getRawEndpointReturnTypeReference } from "./utils/getRawEndpointReturnTypeReference"; + +export abstract class AbstractEndpointGenerator { + protected readonly context: SdkGeneratorContext; + + public constructor({ context }: { context: SdkGeneratorContext }) { + this.context = context; + } + + public getEndpointSignatureInfo({ + serviceId, + service, + endpoint + }: { + serviceId: ServiceId; + service: HttpService; + endpoint: HttpEndpoint; + }): EndpointSignatureInfo { + const { pathParameters, pathParameterReferences } = this.getAllPathParameters({ serviceId, endpoint }); + const request = getEndpointRequest({ context: this.context, endpoint, serviceId, service }); + const requestParameter = request != null ? this.getRequestParameter({ request }) : undefined; + const allParameters = [ + this.context.getContextParameter(), + ...pathParameters, + requestParameter, + endpoint.idempotent + ? this.context.getVariadicIdempotentRequestOptionParameter() + : this.context.getVariadicRequestOptionParameter() + ].filter((p): p is go.Parameter => p != null); + const returnType = getEndpointReturnTypes({ context: this.context, endpoint }); + const rawReturnTypeReference = getRawEndpointReturnTypeReference({ context: this.context, endpoint }); + const returnZeroValue = getEndpointReturnZeroValues({ context: this.context, endpoint }); + return { + allParameters, + pathParameters, + pathParameterReferences, + request, + requestParameter, + returnType, + rawReturnTypeReference, + returnZeroValue + }; + } + + private getAllPathParameters({ + serviceId, + endpoint + }: { + serviceId: ServiceId; + endpoint: HttpEndpoint; + }): Pick { + const service = this.context.getHttpServiceOrThrow(serviceId); + const includePathParametersInSignature = this.includePathParametersInEndpointSignature({ endpoint }); + const pathParameters: go.Parameter[] = []; + const pathParameterReferences: Record = {}; + for (const pathParam of [ + ...this.context.ir.pathParameters, + ...service.pathParameters, + ...endpoint.pathParameters + ]) { + const parameterName = this.context.getParameterName(pathParam.name); + pathParameterReferences[pathParam.name.originalName] = this.accessPathParameterValue({ + pathParameter: pathParam, + sdkRequest: endpoint.sdkRequest, + includePathParametersInEndpointSignature: includePathParametersInSignature + }); + if (includePathParametersInSignature) { + pathParameters.push( + go.parameter({ + docs: pathParam.docs, + name: parameterName, + type: this.context.goTypeMapper.convert({ reference: pathParam.valueType }) + }) + ); + } + } + return { + pathParameters, + pathParameterReferences + }; + } + + private getRequestParameter({ request }: { request: EndpointRequest }): go.Parameter { + return go.parameter({ + type: request.getRequestParameterType(), + name: request.getRequestParameterName() + }); + } + + private includePathParametersInEndpointSignature({ endpoint }: { endpoint: HttpEndpoint }): boolean { + const shape = endpoint.sdkRequest?.shape; + if (shape == null) { + return true; + } + switch (shape.type) { + case "wrapper": { + return !this.context.includePathParametersInWrappedRequest({ endpoint, wrapper: shape }); + } + case "justRequestBody": { + return true; + } + default: { + assertNever(shape); + } + } + } + + private accessPathParameterValue({ + sdkRequest, + pathParameter, + includePathParametersInEndpointSignature + }: { + sdkRequest: SdkRequest | undefined; + pathParameter: PathParameter; + includePathParametersInEndpointSignature: boolean; + }): string { + if (sdkRequest == null || includePathParametersInEndpointSignature) { + return this.context.getParameterName(pathParameter.name); + } + return this.context.accessRequestProperty({ + requestParameterName: sdkRequest.requestParameterName, + propertyName: pathParameter.name + }); + } +} diff --git a/generators/go-v2/sdk/src/endpoint/EndpointGenerator.ts b/generators/go-v2/sdk/src/endpoint/EndpointGenerator.ts new file mode 100644 index 00000000000..6f4b69ce2cf --- /dev/null +++ b/generators/go-v2/sdk/src/endpoint/EndpointGenerator.ts @@ -0,0 +1,54 @@ +import { go } from "@fern-api/go-ast"; + +import { HttpEndpoint, HttpService, ServiceId, Subpackage } from "@fern-fern/ir-sdk/api"; + +import { SdkGeneratorContext } from "../SdkGeneratorContext"; +import { AbstractEndpointGenerator } from "./AbstractEndpointGenerator"; +import { HttpEndpointGenerator } from "./http/HttpEndpointGenerator"; + +export class EndpointGenerator extends AbstractEndpointGenerator { + private http: HttpEndpointGenerator; + + public constructor(context: SdkGeneratorContext) { + super({ context }); + this.http = new HttpEndpointGenerator({ context }); + } + + public generate({ + serviceId, + service, + subpackage, + endpoint + }: { + serviceId: ServiceId; + service: HttpService; + subpackage: Subpackage | undefined; + endpoint: HttpEndpoint; + }): go.Method[] { + return this.http.generate({ + serviceId, + service, + subpackage, + endpoint + }); + } + + public generateRaw({ + serviceId, + service, + subpackage, + endpoint + }: { + serviceId: ServiceId; + service: HttpService; + subpackage: Subpackage | undefined; + endpoint: HttpEndpoint; + }): go.Method[] { + return this.http.generateRaw({ + serviceId, + service, + subpackage, + endpoint + }); + } +} diff --git a/generators/go-v2/sdk/src/endpoint/EndpointSignatureInfo.ts b/generators/go-v2/sdk/src/endpoint/EndpointSignatureInfo.ts new file mode 100644 index 00000000000..31a1fe8c079 --- /dev/null +++ b/generators/go-v2/sdk/src/endpoint/EndpointSignatureInfo.ts @@ -0,0 +1,17 @@ +import { go } from "@fern-api/go-ast"; + +import { EndpointRequest } from "./request/EndpointRequest"; + +export interface EndpointSignatureInfo { + allParameters: go.Parameter[]; + pathParameters: go.Parameter[]; + pathParameterReferences: Record; + request: EndpointRequest | undefined; + requestParameter: go.Parameter | undefined; + rawReturnTypeReference: go.TypeReference; + + // All endpoints return an error by default; these fields are only set + // if the endpoint returns a non-error value. + returnType: go.Type | undefined; + returnZeroValue: go.TypeInstantiation | undefined; +} diff --git a/generators/go-v2/sdk/src/endpoint/http/HttpEndpointGenerator.ts b/generators/go-v2/sdk/src/endpoint/http/HttpEndpointGenerator.ts new file mode 100644 index 00000000000..bd5ae8e9cd4 --- /dev/null +++ b/generators/go-v2/sdk/src/endpoint/http/HttpEndpointGenerator.ts @@ -0,0 +1,641 @@ +import { write } from "fs"; + +import { assertNever } from "@fern-api/core-utils"; +import { go } from "@fern-api/go-ast"; + +import { + HttpEndpoint, + HttpRequestBody, + HttpResponseBody, + HttpService, + JsonResponse, + SdkRequestBodyType, + SdkRequestWrapper, + ServiceId, + StreamingResponse, + Subpackage +} from "@fern-fern/ir-sdk/api"; + +import { SdkGeneratorContext } from "../../SdkGeneratorContext"; +import { AbstractEndpointGenerator } from "../AbstractEndpointGenerator"; +import { EndpointSignatureInfo } from "../EndpointSignatureInfo"; +import { EndpointRequest } from "../request/EndpointRequest"; +import { getEndpointRequest } from "../utils/getEndpointRequest"; + +export declare namespace EndpointGenerator { + const OCTET_STREAM_CONTENT_TYPE = "application/octet-stream"; + + export interface Args { + endpoint: HttpEndpoint; + } +} + +export class HttpEndpointGenerator extends AbstractEndpointGenerator { + public constructor({ context }: { context: SdkGeneratorContext }) { + super({ context }); + } + + public generate({ + serviceId, + service, + subpackage, + endpoint + }: { + serviceId: ServiceId; + service: HttpService; + subpackage: Subpackage | undefined; + endpoint: HttpEndpoint; + }): go.Method[] { + const methods: go.Method[] = []; + return methods; + } + + public generateRaw({ + serviceId, + service, + subpackage, + endpoint + }: { + serviceId: ServiceId; + service: HttpService; + subpackage: Subpackage | undefined; + endpoint: HttpEndpoint; + }): go.Method[] { + const endpointRequest = getEndpointRequest({ context: this.context, endpoint, serviceId, service }); + return [this.generateRawUnaryEndpoint({ serviceId, service, endpoint, subpackage, endpointRequest })]; + } + + private generateRawUnaryEndpoint({ + serviceId, + service, + endpoint, + subpackage, + endpointRequest + }: { + serviceId: ServiceId; + service: HttpService; + endpoint: HttpEndpoint; + subpackage: Subpackage | undefined; + endpointRequest: EndpointRequest | undefined; + }): go.Method { + const signature = this.getEndpointSignatureInfo({ serviceId, service, endpoint }); + return new go.Method({ + name: this.context.getMethodName(endpoint.name), + parameters: signature.allParameters, + return_: this.getRawReturnSignature({ signature }), + body: this.getRawUnaryEndpointBody({ signature, endpoint, endpointRequest }), + typeReference: this.getRawClientTypeReference({ subpackage }) + }); + } + + private getRawReturnSignature({ signature }: { signature: EndpointSignatureInfo }): go.Type[] { + return [go.Type.pointer(go.Type.reference(signature.rawReturnTypeReference)), go.Type.error()]; + } + + private getRawClientTypeReference({ subpackage }: { subpackage: Subpackage | undefined }): go.TypeReference { + if (subpackage == null) { + return this.context.getRootRawClientClassReference(); + } + return this.context.getSubpackageRawClientClassReference(subpackage); + } + + private getRawUnaryEndpointBody({ + signature, + endpoint, + endpointRequest + }: { + signature: EndpointSignatureInfo; + endpoint: HttpEndpoint; + endpointRequest: EndpointRequest | undefined; + }): go.CodeBlock { + return go.codeblock((writer) => { + writer.writeNode(this.buildRequestOptions({ endpoint })); + writer.newLine(); + writer.writeNode(this.buildBaseUrl({ endpoint })); + writer.newLine(); + writer.writeNode(this.buildEndpointUrl({ endpoint, signature })); + + const buildQueryParameters = this.buildQueryParameters({ signature, endpoint, endpointRequest }); + if (buildQueryParameters != null) { + writer.newLine(); + writer.writeNode(buildQueryParameters); + } + + const buildHeaders = this.buildHeaders({ endpoint }); + writer.newLine(); + writer.writeNode(buildHeaders); + + const buildErrorDecoder = this.buildErrorDecoder({ endpoint }); + if (buildErrorDecoder != null) { + writer.newLine(); + writer.writeNode(buildErrorDecoder); + } + + const responseInitialization = this.getResponseInitialization({ endpoint }); + if (responseInitialization != null) { + writer.newLine(); + writer.writeNode(responseInitialization); + } + + writer.newLine(); + writer.write("raw, err := "); + writer.writeNode( + this.context.caller.call({ + endpoint, + clientReference: this.getCallerFieldReference(), + optionsReference: go.codeblock("options"), + url: go.codeblock("endpointURL"), + request: endpointRequest?.getRequestReference(), + response: this.getResponseParameterReference({ endpoint }) + }) + ); + writer.newLine(); + writer.writeLine("if err != nil {"); + writer.indent(); + writer.writeNode(this.writeRawReturnZeroValueWithError()); + writer.newLine(); + writer.dedent(); + writer.writeLine("}"); + + writer.writeNode(this.getRawResponseReturnStatement({ endpoint, signature })); + }); + } + + private buildRequestOptions({ endpoint }: { endpoint: HttpEndpoint }): go.CodeBlock { + const requestOptions = endpoint.idempotent + ? this.context.callNewIdempotentRequestOptions(go.codeblock("opts...")) + : this.context.callNewRequestOptions(go.codeblock("opts...")); + + return go.codeblock((writer) => { + writer.write("options := "); + writer.writeNode(requestOptions); + }); + } + + private buildBaseUrl({ endpoint }: { endpoint: HttpEndpoint }): go.CodeBlock { + return go.codeblock((writer) => { + writer.write("baseURL := "); + writer.writeNode( + this.context.callResolveBaseURL([ + go.selector({ + on: go.codeblock("options"), + selector: go.codeblock("BaseURL") + }), + go.selector({ + on: this.getRawClientReceiverCodeBlock(), + selector: go.codeblock("baseURL") + }), + this.context.getDefaultBaseUrlTypeInstantiation(endpoint) + ]) + ); + }); + } + + private buildEndpointUrl({ + endpoint, + signature + }: { + endpoint: HttpEndpoint; + signature: EndpointSignatureInfo; + }): go.CodeBlock { + const pathSuffix = this.getPathSuffix({ endpoint }); + const baseUrl = pathSuffix.length === 0 ? "baseURL" : `baseURL + "/${pathSuffix}"`; + return go.codeblock((writer) => { + writer.write("endpointURL := "); + if (endpoint.allPathParameters.length === 0) { + writer.write(baseUrl); + return; + } + const pathParameterReferences: go.AstNode[] = []; + for (const pathParameter of endpoint.allPathParameters) { + const pathParameterReference = signature.pathParameterReferences[pathParameter.name.originalName]; + if (pathParameterReference == null) { + continue; + } + pathParameterReferences.push(go.codeblock(pathParameterReference)); + } + writer.writeNode(this.context.callEncodeUrl([go.codeblock(baseUrl), ...pathParameterReferences])); + }); + } + + private buildQueryParameters({ + signature, + endpoint, + endpointRequest + }: { + signature: EndpointSignatureInfo; + endpoint: HttpEndpoint; + endpointRequest: EndpointRequest | undefined; + }): go.CodeBlock | undefined { + if (endpointRequest == null || endpoint.queryParameters.length === 0) { + return undefined; + } + return go.codeblock((writer) => { + writer.write("queryParams, err := "); + writer.writeNode(this.context.callQueryValues([go.codeblock(endpointRequest.getRequestParameterName())])); + writer.write("if err != nil {"); + writer.indent(); + writer.writeNode(this.writeRawReturnZeroValueWithError()); + writer.dedent(); + writer.write("}"); + for (const queryParameter of endpoint.queryParameters) { + const literal = this.context.maybeLiteral(queryParameter.valueType); + if (literal != null) { + writer.write( + `queryParams.Add("${queryParameter.name}", ${this.context.getLiteralAsString(literal)})` + ); + continue; + } + } + writer.write("if len(queryParams) > 0 {"); + writer.indent(); + writer.write(`endpointURL += "?" + queryParams.Encode()`); + writer.dedent(); + writer.write("}"); + }); + } + + private buildHeaders({ endpoint }: { endpoint: HttpEndpoint }): go.CodeBlock { + return go.codeblock((writer) => { + writer.write("headers := "); + writer.writeNode( + this.context.callMergeHeaders([ + go.codeblock(`${this.getRawClientReceiver()}.header.Clone()`), + go.codeblock("options.ToHeader()") + ]) + ); + if (endpoint.headers.length > 0) { + writer.newLine(); + } + for (const header of endpoint.headers) { + const literal = this.context.maybeLiteral(header.valueType); + if (literal != null) { + writer.writeNode( + this.addHeaderValue({ + wireValue: header.name.wireValue, + value: go.codeblock(this.context.getLiteralAsString(literal)) + }) + ); + continue; + } + const headerField = go.codeblock( + `${this.getRequestParameterName({ endpoint })}.${this.context.getFieldName(header.name.name)}` + ); + const format = this.context.goValueFormatter.convert({ + reference: header.valueType, + value: headerField + }); + if (format.isOptional) { + writer.write(`if ${headerField} != nil {`); + writer.indent(); + writer.writeNode( + this.addHeaderValue({ wireValue: header.name.wireValue, value: format.formatted }) + ); + writer.dedent(); + writer.write("}"); + continue; + } + writer.writeNode(this.addHeaderValue({ wireValue: header.name.wireValue, value: format.formatted })); + } + const acceptHeader = this.getAcceptHeaderValue({ endpoint }); + if (acceptHeader != null) { + writer.writeNode(this.setHeaderValue({ wireValue: "Accept", value: acceptHeader })); + } + const contentTypeHeader = this.getContentTypeHeaderValue({ endpoint }); + if (contentTypeHeader != null) { + writer.writeNode(this.setHeaderValue({ wireValue: "Content-Type", value: contentTypeHeader })); + } + }); + } + + private buildErrorDecoder({ endpoint }: { endpoint: HttpEndpoint }): go.CodeBlock | undefined { + if (endpoint.errors.length === 0) { + return undefined; + } + return go.codeblock((writer) => { + writer.write("errorCodes := "); + writer.writeNode( + go.TypeInstantiation.struct({ + typeReference: this.context.getErrorCodesTypeReference(), + fields: endpoint.errors.map((error) => { + const errorDeclaration = this.context.getErrorDeclarationOrThrow(error.error.errorId); + const errorTypeReference = go.typeReference({ + name: this.context.getClassName(errorDeclaration.name.name), + importPath: this.context.getLocationForTypeId(errorDeclaration.name.errorId).importPath + }); + return { + name: errorDeclaration.statusCode.toString(), + value: go.TypeInstantiation.reference( + go.func({ + parameters: [ + go.parameter({ + name: "apiError", + type: go.Type.reference(this.context.getCoreApiErrorTypeReference()) + }) + ], + return_: [go.Type.reference(errorTypeReference)], + body: go.codeblock((writer) => { + writer.write("return "); + writer.writeNode( + go.TypeInstantiation.structPointer({ + typeReference: this.context.getCoreApiErrorTypeReference(), + fields: [ + { + name: "APIError", + value: go.TypeInstantiation.reference(go.codeblock("apiError")) + } + ] + }) + ); + }) + }) + ) + }; + }) + }) + ); + }); + } + + private getResponseInitialization({ endpoint }: { endpoint: HttpEndpoint }): go.CodeBlock | undefined { + const responseBody = endpoint.response?.body; + if (responseBody == null) { + return undefined; + } + switch (responseBody.type) { + case "json": + return go.codeblock((writer) => { + writer.write("var response "); + writer.writeNode( + this.context.goTypeMapper.convert({ reference: responseBody.value.responseBodyType }) + ); + }); + case "fileDownload": + case "text": + return go.codeblock((writer) => { + writer.write("response := "); + writer.writeNode(this.context.callBytesNewBuffer()); + }); + case "streaming": + case "streamParameter": + // TODO: Implement stream responses. + return undefined; + case "bytes": + return undefined; + default: + assertNever(responseBody); + } + } + + private getResponseParameterReference({ endpoint }: { endpoint: HttpEndpoint }): go.CodeBlock | undefined { + const responseBody = endpoint.response?.body; + if (responseBody == null) { + return undefined; + } + switch (responseBody.type) { + case "json": + return go.codeblock("&response"); + case "bytes": + case "fileDownload": + case "text": + case "streaming": + case "streamParameter": + return go.codeblock("response"); + default: + assertNever(responseBody); + } + } + + private getRawResponseReturnStatement({ + endpoint, + signature + }: { + endpoint: HttpEndpoint; + signature: EndpointSignatureInfo; + }): go.CodeBlock { + const responseBody = endpoint.response?.body; + const responseBodyReference = + responseBody == null ? go.TypeInstantiation.nil() : this.getResponseBodyReference({ responseBody }); + return go.codeblock((writer) => { + writer.write("return "); + if (signature.rawReturnTypeReference != null) { + writer.writeNode( + this.wrapWithRawResponseType({ + context: this.context, + rawReturnTypeReference: signature.rawReturnTypeReference, + responseBodyReference + }) + ); + writer.write(", "); + } + writer.write("nil"); + }); + } + + private getResponseBodyReference({ responseBody }: { responseBody: HttpResponseBody }): go.CodeBlock { + switch (responseBody.type) { + case "json": + return this.getResponseBodyReferenceForJson({ jsonResponse: responseBody.value }); + case "bytes": + case "fileDownload": + return go.codeblock("response"); + case "text": + return go.codeblock("response.String()"); + case "streaming": + case "streamParameter": + // TODO: Implement stream responses. + return go.codeblock("nil"); + default: + assertNever(responseBody); + } + } + + private getResponseBodyReferenceForJson({ jsonResponse }: { jsonResponse: JsonResponse }): go.CodeBlock { + switch (jsonResponse.type) { + case "response": + return go.codeblock("response"); + case "nestedPropertyAsResponse": + const responseProperty = jsonResponse.responseProperty; + if (responseProperty == null) { + return go.codeblock("response"); + } + return go.codeblock(`response.${this.context.getFieldName(responseProperty.name.name)}`); + default: + assertNever(jsonResponse); + } + } + + private wrapWithRawResponseType({ + context, + rawReturnTypeReference, + responseBodyReference + }: { + context: SdkGeneratorContext; + rawReturnTypeReference: go.TypeReference; + responseBodyReference: go.AstNode; + }): go.TypeInstantiation { + return go.TypeInstantiation.structPointer({ + typeReference: rawReturnTypeReference, + fields: [ + { + name: "StatusCode", + value: go.TypeInstantiation.reference(go.codeblock("raw.StatusCode")) + }, + { + name: "Header", + value: go.TypeInstantiation.reference(go.codeblock("raw.Header")) + }, + { + name: "Body", + value: go.TypeInstantiation.reference(responseBodyReference) + } + ] + }); + } + + private getPathSuffix({ endpoint }: { endpoint: HttpEndpoint }): string { + let pathSuffix = endpoint.fullPath.head === "/" ? "" : endpoint.fullPath.head; + for (const part of endpoint.fullPath.parts) { + if (part.pathParameter) { + pathSuffix += "%v"; + } + pathSuffix += part.tail; + } + return pathSuffix.replace(/^\/+/, ""); + } + + private getAcceptHeaderValue({ endpoint }: { endpoint: HttpEndpoint }): string | undefined { + const responseBody = endpoint.response?.body; + if (responseBody == null) { + return undefined; + } + switch (responseBody.type) { + case "streaming": + return this.getAcceptHeaderValueForStreaming({ streamingResponse: responseBody.value }); + case "bytes": + case "fileDownload": + case "json": + case "streamParameter": + case "text": + return undefined; + default: + assertNever(responseBody); + } + } + + private getAcceptHeaderValueForStreaming({ + streamingResponse + }: { + streamingResponse: StreamingResponse; + }): string | undefined { + switch (streamingResponse.type) { + case "sse": + return "text/event-stream"; + case "json": + case "text": + return undefined; + default: + assertNever(streamingResponse); + } + } + + private getContentTypeHeaderValue({ endpoint }: { endpoint: HttpEndpoint }): string | undefined { + const sdkRequest = endpoint.sdkRequest; + if (sdkRequest == null) { + return undefined; + } + switch (sdkRequest.shape.type) { + case "justRequestBody": + return this.getContentTypeHeaderValueForJustRequestBody({ justRequestBody: sdkRequest.shape.value }); + case "wrapper": { + const requestBody = endpoint.requestBody; + if (requestBody == null) { + return undefined; + } + return this.getContentTypeHeaderValueForWrapper({ wrapper: sdkRequest.shape, requestBody }); + } + default: + assertNever(sdkRequest.shape); + } + } + + private getContentTypeHeaderValueForJustRequestBody({ + justRequestBody + }: { + justRequestBody: SdkRequestBodyType; + }): string | undefined { + switch (justRequestBody.type) { + case "bytes": + return justRequestBody.contentType ?? EndpointGenerator.OCTET_STREAM_CONTENT_TYPE; + case "typeReference": + return justRequestBody.contentType; + default: + assertNever(justRequestBody); + } + } + + private getContentTypeHeaderValueForWrapper({ + wrapper, + requestBody + }: { + wrapper: SdkRequestWrapper; + requestBody: HttpRequestBody; + }): string | undefined { + switch (requestBody.type) { + case "bytes": + return requestBody.contentType ?? EndpointGenerator.OCTET_STREAM_CONTENT_TYPE; + case "fileUpload": + case "inlinedRequestBody": + case "reference": + return requestBody.contentType; + default: + assertNever(requestBody); + } + } + + private addHeaderValue({ wireValue, value }: { wireValue: string; value: go.AstNode }): go.CodeBlock { + return go.codeblock(`headers.Add("${wireValue}", fmt.Sprintf("%v", ${value}))`); + } + + private setHeaderValue({ wireValue, value }: { wireValue: string; value: string }): go.CodeBlock { + return go.codeblock(`headers.Add("${wireValue}", "${value}")`); + } + + private writeRawReturnZeroValueWithError(): go.CodeBlock { + return go.codeblock("return nil, err"); + } + + private writeReturnZeroValueWithError({ zeroValue }: { zeroValue?: go.TypeInstantiation }): go.CodeBlock { + return go.codeblock((writer) => { + writer.write(`return `); + if (zeroValue != null) { + writer.writeNode(zeroValue); + writer.write(", "); + } + writer.write("err"); + }); + } + + private getRawClientReceiverCodeBlock(): go.AstNode { + return go.codeblock(this.getRawClientReceiver()); + } + + private getCallerFieldReference(): go.AstNode { + return go.selector({ + on: this.getRawClientReceiverCodeBlock(), + selector: go.codeblock("caller") + }); + } + + private getRequestParameterName({ endpoint }: { endpoint: HttpEndpoint }): string { + const requestParameterName = endpoint.sdkRequest?.requestParameterName; + if (requestParameterName == null) { + return "request"; + } + return this.context.getParameterName(requestParameterName); + } + + private getRawClientReceiver(): string { + return "r"; + } +} diff --git a/generators/go-v2/sdk/src/endpoint/request/BytesRequest.ts b/generators/go-v2/sdk/src/endpoint/request/BytesRequest.ts new file mode 100644 index 00000000000..1e98d07ee05 --- /dev/null +++ b/generators/go-v2/sdk/src/endpoint/request/BytesRequest.ts @@ -0,0 +1,44 @@ +import { go } from "@fern-api/go-ast"; + +import { HttpEndpoint, HttpService, SdkRequest, TypeReference } from "@fern-fern/ir-sdk/api"; + +import { SdkGeneratorContext } from "../../SdkGeneratorContext"; +import { + EndpointRequest, + HeaderParameterCodeBlock, + QueryParameterCodeBlock, + RequestBodyCodeBlock +} from "./EndpointRequest"; + +export class BytesRequest extends EndpointRequest { + public constructor( + context: SdkGeneratorContext, + sdkRequest: SdkRequest, + service: HttpService, + endpoint: HttpEndpoint + ) { + super(context, sdkRequest, service, endpoint); + } + + public getRequestParameterType(): go.Type { + return go.Type.bytes(); + } + + public getRequestReference(): go.AstNode { + return go.codeblock("requestBuffer"); + } + + public getQueryParameterCodeBlock(): QueryParameterCodeBlock | undefined { + return undefined; + } + + public getHeaderParameterCodeBlock(): HeaderParameterCodeBlock | undefined { + return undefined; + } + + public getRequestBodyCodeBlock(): RequestBodyCodeBlock | undefined { + return { + requestBodyReference: go.codeblock(this.getRequestParameterName()) + }; + } +} diff --git a/generators/go-v2/sdk/src/endpoint/request/EndpointRequest.ts b/generators/go-v2/sdk/src/endpoint/request/EndpointRequest.ts new file mode 100644 index 00000000000..6ae560cea8c --- /dev/null +++ b/generators/go-v2/sdk/src/endpoint/request/EndpointRequest.ts @@ -0,0 +1,39 @@ +import { go } from "@fern-api/go-ast"; + +import { HttpEndpoint, HttpService, SdkRequest } from "@fern-fern/ir-sdk/api"; + +import { SdkGeneratorContext } from "../../SdkGeneratorContext"; + +export interface QueryParameterCodeBlock { + code: go.CodeBlock; + queryParameterBagReference: string; +} + +export interface HeaderParameterCodeBlock { + code: go.CodeBlock; + headerParameterBagReference: string; +} + +export interface RequestBodyCodeBlock { + code?: go.CodeBlock; + requestBodyReference: go.CodeBlock; +} + +export abstract class EndpointRequest { + public constructor( + protected readonly context: SdkGeneratorContext, + protected readonly sdkRequest: SdkRequest, + protected readonly service: HttpService, + protected readonly endpoint: HttpEndpoint + ) {} + + public abstract getRequestParameterType(): go.Type; + public abstract getRequestReference(): go.AstNode; + public abstract getQueryParameterCodeBlock(): QueryParameterCodeBlock | undefined; + public abstract getHeaderParameterCodeBlock(): HeaderParameterCodeBlock | undefined; + public abstract getRequestBodyCodeBlock(): RequestBodyCodeBlock | undefined; + + public getRequestParameterName(): string { + return this.context.getParameterName(this.sdkRequest.requestParameterName); + } +} diff --git a/generators/go-v2/sdk/src/endpoint/request/ReferencedEndpointRequest.ts b/generators/go-v2/sdk/src/endpoint/request/ReferencedEndpointRequest.ts new file mode 100644 index 00000000000..9f213a85b21 --- /dev/null +++ b/generators/go-v2/sdk/src/endpoint/request/ReferencedEndpointRequest.ts @@ -0,0 +1,48 @@ +import { go } from "@fern-api/go-ast"; + +import { HttpEndpoint, HttpService, SdkRequest, TypeReference } from "@fern-fern/ir-sdk/api"; + +import { SdkGeneratorContext } from "../../SdkGeneratorContext"; +import { + EndpointRequest, + HeaderParameterCodeBlock, + QueryParameterCodeBlock, + RequestBodyCodeBlock +} from "./EndpointRequest"; + +export class ReferencedEndpointRequest extends EndpointRequest { + private requestBodyShape: TypeReference; + + public constructor( + context: SdkGeneratorContext, + sdkRequest: SdkRequest, + service: HttpService, + endpoint: HttpEndpoint, + requestBodyShape: TypeReference + ) { + super(context, sdkRequest, service, endpoint); + this.requestBodyShape = requestBodyShape; + } + + public getRequestParameterType(): go.Type { + return this.context.goTypeMapper.convert({ reference: this.requestBodyShape }); + } + + public getRequestReference(): go.AstNode { + return go.codeblock(this.getRequestParameterName()); + } + + public getQueryParameterCodeBlock(): QueryParameterCodeBlock | undefined { + return undefined; + } + + public getHeaderParameterCodeBlock(): HeaderParameterCodeBlock | undefined { + return undefined; + } + + public getRequestBodyCodeBlock(): RequestBodyCodeBlock | undefined { + return { + requestBodyReference: go.codeblock(this.getRequestParameterName()) + }; + } +} diff --git a/generators/go-v2/sdk/src/endpoint/request/WrappedEndpointRequest.ts b/generators/go-v2/sdk/src/endpoint/request/WrappedEndpointRequest.ts new file mode 100644 index 00000000000..4096d92ed10 --- /dev/null +++ b/generators/go-v2/sdk/src/endpoint/request/WrappedEndpointRequest.ts @@ -0,0 +1,67 @@ +import { go } from "@fern-api/go-ast"; + +import { + HttpEndpoint, + HttpService, + Name, + SdkRequest, + SdkRequestWrapper, + ServiceId, + TypeReference +} from "@fern-fern/ir-sdk/api"; + +import { SdkGeneratorContext } from "../../SdkGeneratorContext"; +import { + EndpointRequest, + HeaderParameterCodeBlock, + QueryParameterCodeBlock, + RequestBodyCodeBlock +} from "./EndpointRequest"; + +export declare namespace WrappedEndpointRequest { + interface Args { + context: SdkGeneratorContext; + serviceId: ServiceId; + sdkRequest: SdkRequest; + wrapper: SdkRequestWrapper; + service: HttpService; + endpoint: HttpEndpoint; + } +} + +export class WrappedEndpointRequest extends EndpointRequest { + private serviceId: ServiceId; + private wrapper: SdkRequestWrapper; + + public constructor({ context, sdkRequest, serviceId, wrapper, service, endpoint }: WrappedEndpointRequest.Args) { + super(context, sdkRequest, service, endpoint); + this.serviceId = serviceId; + this.wrapper = wrapper; + } + + public getRequestParameterType(): go.Type { + return go.Type.pointer( + go.Type.reference(this.context.getRequestWrapperTypeReference(this.serviceId, this.wrapper.wrapperName)) + ); + } + + public getRequestReference(): go.AstNode { + return go.codeblock(this.getRequestParameterName()); + } + + public getQueryParameterCodeBlock(): QueryParameterCodeBlock | undefined { + // TODO: Implement this. + return undefined; + } + + public getHeaderParameterCodeBlock(): HeaderParameterCodeBlock | undefined { + // TODO: Implement this. + return undefined; + } + + public getRequestBodyCodeBlock(): RequestBodyCodeBlock | undefined { + return { + requestBodyReference: go.codeblock(this.getRequestParameterName()) + }; + } +} diff --git a/generators/go-v2/sdk/src/endpoint/utils/getEndpointRequest.ts b/generators/go-v2/sdk/src/endpoint/utils/getEndpointRequest.ts new file mode 100644 index 00000000000..5d9d7f4e038 --- /dev/null +++ b/generators/go-v2/sdk/src/endpoint/utils/getEndpointRequest.ts @@ -0,0 +1,76 @@ +import { assertNever } from "@fern-api/core-utils"; + +import { HttpEndpoint, HttpService, SdkRequest, ServiceId } from "@fern-fern/ir-sdk/api"; + +import { SdkGeneratorContext } from "../../SdkGeneratorContext"; +import { BytesRequest } from "../request/BytesRequest"; +import { EndpointRequest } from "../request/EndpointRequest"; +import { ReferencedEndpointRequest } from "../request/ReferencedEndpointRequest"; +import { WrappedEndpointRequest } from "../request/WrappedEndpointRequest"; + +export function getEndpointRequest({ + context, + endpoint, + serviceId, + service +}: { + context: SdkGeneratorContext; + endpoint: HttpEndpoint; + serviceId: ServiceId; + service: HttpService; +}): EndpointRequest | undefined { + if (endpoint.sdkRequest == null) { + return undefined; + } + if (endpoint.sdkRequest.shape.type === "wrapper") { + if (context.shouldSkipWrappedRequest({ endpoint, wrapper: endpoint.sdkRequest.shape })) { + return undefined; + } + } + return createEndpointRequest({ + context, + endpoint, + serviceId, + service, + sdkRequest: endpoint.sdkRequest + }); +} + +function createEndpointRequest({ + context, + sdkRequest, + endpoint, + service, + serviceId +}: { + context: SdkGeneratorContext; + sdkRequest: SdkRequest; + endpoint: HttpEndpoint; + service: HttpService; + serviceId: ServiceId; +}): EndpointRequest | undefined { + switch (sdkRequest.shape.type) { + case "wrapper": + return new WrappedEndpointRequest({ + context, + serviceId, + sdkRequest, + wrapper: sdkRequest.shape, + service, + endpoint + }); + case "justRequestBody": + if (sdkRequest.shape.value.type === "bytes") { + return new BytesRequest(context, sdkRequest, service, endpoint); + } + return new ReferencedEndpointRequest( + context, + sdkRequest, + service, + endpoint, + sdkRequest.shape.value.requestBodyType + ); + default: + assertNever(sdkRequest.shape); + } +} diff --git a/generators/go-v2/sdk/src/endpoint/utils/getEndpointReturnTypes.ts b/generators/go-v2/sdk/src/endpoint/utils/getEndpointReturnTypes.ts new file mode 100644 index 00000000000..eb9a006acc8 --- /dev/null +++ b/generators/go-v2/sdk/src/endpoint/utils/getEndpointReturnTypes.ts @@ -0,0 +1,36 @@ +import { assertNever } from "@fern-api/core-utils"; +import { go } from "@fern-api/go-ast"; + +import { HttpEndpoint } from "@fern-fern/ir-sdk/api"; + +import { SdkGeneratorContext } from "../../SdkGeneratorContext"; + +export function getEndpointReturnTypes({ + context, + endpoint +}: { + context: SdkGeneratorContext; + endpoint: HttpEndpoint; +}): go.Type | undefined { + const response = endpoint.response; + if (response?.body == null) { + return undefined; + } + const body = response.body; + switch (body.type) { + case "bytes": + return go.Type.bytes(); + case "streamParameter": + return go.Type.any(); + case "fileDownload": + return go.Type.reference(context.getIoReaderTypeReference()); + case "json": + return context.goTypeMapper.convert({ reference: body.value.responseBodyType }); + case "streaming": + return go.Type.reference(context.getStreamTypeReference(context.getStreamPayload(body.value))); + case "text": + return go.Type.string(); + default: + assertNever(body); + } +} diff --git a/generators/go-v2/sdk/src/endpoint/utils/getEndpointReturnZeroValue.ts b/generators/go-v2/sdk/src/endpoint/utils/getEndpointReturnZeroValue.ts new file mode 100644 index 00000000000..499777a03d4 --- /dev/null +++ b/generators/go-v2/sdk/src/endpoint/utils/getEndpointReturnZeroValue.ts @@ -0,0 +1,33 @@ +import { assertNever } from "@fern-api/core-utils"; +import { go } from "@fern-api/go-ast"; + +import { HttpEndpoint } from "@fern-fern/ir-sdk/api"; + +import { SdkGeneratorContext } from "../../SdkGeneratorContext"; + +export function getEndpointReturnZeroValues({ + context, + endpoint +}: { + context: SdkGeneratorContext; + endpoint: HttpEndpoint; +}): go.TypeInstantiation | undefined { + const response = endpoint.response; + if (response?.body == null) { + return undefined; + } + const body = response.body; + switch (body.type) { + case "json": + return context.goZeroValueMapper.convert({ reference: body.value.responseBodyType }); + case "text": + return go.TypeInstantiation.string(""); + case "bytes": + case "streamParameter": + case "fileDownload": + case "streaming": + return go.TypeInstantiation.nil(); + default: + assertNever(body); + } +} diff --git a/generators/go-v2/sdk/src/endpoint/utils/getRawEndpointReturnTypeReference.ts b/generators/go-v2/sdk/src/endpoint/utils/getRawEndpointReturnTypeReference.ts new file mode 100644 index 00000000000..686b3fbf5e7 --- /dev/null +++ b/generators/go-v2/sdk/src/endpoint/utils/getRawEndpointReturnTypeReference.ts @@ -0,0 +1,55 @@ +import { assertNever } from "@fern-api/core-utils"; +import { go } from "@fern-api/go-ast"; + +import { HttpEndpoint } from "@fern-fern/ir-sdk/api"; + +import { SdkGeneratorContext } from "../../SdkGeneratorContext"; + +export function getRawEndpointReturnTypeReference({ + context, + endpoint +}: { + context: SdkGeneratorContext; + endpoint: HttpEndpoint; +}): go.TypeReference { + const response = endpoint.response; + if (response?.body == null) { + return wrapWithRawResponseType({ context, returnType: go.Type.any() }); + } + const body = response.body; + switch (body.type) { + case "bytes": + return wrapWithRawResponseType({ context, returnType: go.Type.bytes() }); + case "fileDownload": + return wrapWithRawResponseType({ + context, + returnType: go.Type.reference(context.getIoReaderTypeReference()) + }); + case "json": + return wrapWithRawResponseType({ + context, + returnType: context.goTypeMapper.convert({ reference: body.value.responseBodyType }) + }); + case "streaming": + return wrapWithRawResponseType({ + context, + returnType: go.Type.reference(context.getStreamTypeReference(context.getStreamPayload(body.value))) + }); + case "streamParameter": + return context.getRawResponseTypeReference(go.Type.any()); + case "text": + return wrapWithRawResponseType({ context, returnType: go.Type.string() }); + default: + assertNever(body); + } +} + +function wrapWithRawResponseType({ + context, + returnType +}: { + context: SdkGeneratorContext; + returnType: go.Type; +}): go.TypeReference { + return context.getRawResponseTypeReference(returnType); +} diff --git a/generators/go-v2/sdk/src/internal/Caller.ts b/generators/go-v2/sdk/src/internal/Caller.ts new file mode 100644 index 00000000000..2ced0c5b848 --- /dev/null +++ b/generators/go-v2/sdk/src/internal/Caller.ts @@ -0,0 +1,174 @@ +import { go } from "@fern-api/go-ast"; + +import { HttpEndpoint } from "@fern-fern/ir-sdk/api"; + +import { SdkGeneratorContext } from "../SdkGeneratorContext"; + +export declare namespace Caller { + export interface CallArgs { + endpoint: HttpEndpoint; + clientReference: go.AstNode; + optionsReference: go.AstNode; + url: go.AstNode; + request?: go.AstNode; + response?: go.AstNode; + } +} + +/** + * Utility class that helps make HTTP calls. + */ +export class Caller { + public static TYPE_NAME = "Caller"; + public static FIELD_NAME = "caller"; + public static CONSTRUCTOR_FUNC_NAME = "NewCaller"; + public static CALLER_PARAMS_TYPE_NAME = "CallerParams"; + public static CALL_PARAMS_TYPE_NAME = "CallParams"; + public static CALL_METHOD_NAME = "Call"; + + private context: SdkGeneratorContext; + + public constructor(context: SdkGeneratorContext) { + this.context = context; + } + + public getTypeReference(): go.TypeReference { + return go.typeReference({ + name: Caller.TYPE_NAME, + importPath: this.context.getInternalImportPath() + }); + } + + public getConstructorTypeReference(): go.TypeReference { + return go.typeReference({ + name: Caller.CONSTRUCTOR_FUNC_NAME, + importPath: this.context.getInternalImportPath() + }); + } + + public getCallerParamsTypeReference(): go.TypeReference { + return go.typeReference({ + name: Caller.CALLER_PARAMS_TYPE_NAME, + importPath: this.context.getInternalImportPath() + }); + } + + public getCallParamsTypeReference(): go.TypeReference { + return go.typeReference({ + name: Caller.CALL_PARAMS_TYPE_NAME, + importPath: this.context.getInternalImportPath() + }); + } + + public getFieldName(): string { + return Caller.FIELD_NAME; + } + + public getField(): go.Field { + return go.field({ + name: this.getFieldName(), + type: go.Type.pointer(go.Type.reference(this.getTypeReference())) + }); + } + + public instantiate({ client, maxAttempts }: { client: go.AstNode; maxAttempts: go.AstNode }): go.AstNode { + return go.invokeFunc({ + func: this.getConstructorTypeReference(), + arguments_: [ + go.TypeInstantiation.structPointer({ + typeReference: this.getCallerParamsTypeReference(), + fields: [ + { + name: "Client", + value: go.TypeInstantiation.reference(client) + }, + { + name: "MaxAttempts", + value: go.TypeInstantiation.reference(maxAttempts) + } + ] + }) + ] + }); + } + + public call(args: Caller.CallArgs): go.AstNode { + const arguments_: go.StructField[] = [ + { + name: "URL", + value: go.TypeInstantiation.reference(args.url) + }, + { + name: "Method", + value: go.TypeInstantiation.reference(this.context.getNetHttpMethodTypeReference(args.endpoint.method)) + }, + { + name: "Headers", + value: go.TypeInstantiation.reference(go.codeblock("headers")) + }, + { + name: "MaxAttempts", + value: go.TypeInstantiation.reference( + go.selector({ + on: args.optionsReference, + selector: go.codeblock("MaxAttempts") + }) + ) + }, + { + name: "BodyProperties", + value: go.TypeInstantiation.reference( + go.selector({ + on: args.optionsReference, + selector: go.codeblock("BodyProperties") + }) + ) + }, + { + name: "QueryParameters", + value: go.TypeInstantiation.reference( + go.selector({ + on: args.optionsReference, + selector: go.codeblock("QueryParameters") + }) + ) + }, + { + name: "Client", + value: go.TypeInstantiation.reference( + go.selector({ + on: args.optionsReference, + selector: go.codeblock("HTTPClient") + }) + ) + } + ]; + if (args.request != null) { + arguments_.push({ + name: "Request", + value: go.TypeInstantiation.reference(args.request) + }); + } + if (args.response != null) { + arguments_.push({ + name: "Response", + value: go.TypeInstantiation.reference(args.response) + }); + } + return go.codeblock((writer) => { + writer.writeNode( + go.invokeMethod({ + on: args.clientReference, + method: Caller.CALL_METHOD_NAME, + arguments_: [ + this.context.getContextParameterReference(), + go.TypeInstantiation.structPointer({ + typeReference: this.getCallParamsTypeReference(), + fields: arguments_ + }) + ] + }) + ); + }); + } +} diff --git a/generators/go-v2/sdk/src/module/ModuleConfig.ts b/generators/go-v2/sdk/src/module/ModuleConfig.ts new file mode 100644 index 00000000000..e4bc22a0b5c --- /dev/null +++ b/generators/go-v2/sdk/src/module/ModuleConfig.ts @@ -0,0 +1,19 @@ +export namespace ModuleConfig { + export const FILENAME = "go.mod"; + + export const DEFAULT: ModuleConfig = { + path: "sdk", + version: "1.18", + imports: { + "github.com/google/uuid": "v1.4.0", + "github.com/testify/stretchr": "v1.7.0", + "gopkg.in/yaml.v3": "v3.0.1" + } + }; +} + +export interface ModuleConfig { + path: string; + version?: string; + imports?: Record; +} diff --git a/generators/go-v2/sdk/src/module/ModuleConfigWriter.ts b/generators/go-v2/sdk/src/module/ModuleConfigWriter.ts new file mode 100644 index 00000000000..dd4c2e5d47e --- /dev/null +++ b/generators/go-v2/sdk/src/module/ModuleConfigWriter.ts @@ -0,0 +1,59 @@ +import dedent from "dedent"; + +import { File } from "@fern-api/base-generator"; +import { RelativeFilePath } from "@fern-api/fs-utils"; +import { FileGenerator } from "@fern-api/go-base"; + +import { SdkCustomConfigSchema } from "../SdkCustomConfig"; +import { SdkGeneratorContext } from "../SdkGeneratorContext"; +import { ModuleConfig } from "./ModuleConfig"; + +export class ModuleConfigWriter extends FileGenerator { + private moduleConfig: ModuleConfig; + + public constructor({ context, moduleConfig }: { context: SdkGeneratorContext; moduleConfig: ModuleConfig }) { + super(context); + this.moduleConfig = moduleConfig; + } + + public doGenerate(): File { + return new File(this.getFilepath(), this.getDirectory(), this.getContent()); + } + + protected getFilepath(): RelativeFilePath { + return RelativeFilePath.of(ModuleConfig.FILENAME); + } + + private getDirectory(): RelativeFilePath { + return RelativeFilePath.of("."); + } + + private getContent(): string { + return dedent` + ${this.writeModulePath()} + ${this.writeGoVersion()} + ${this.writeImports()} + `; + } + + private writeModulePath(): string { + return `module ${this.moduleConfig.path}`; + } + + private writeGoVersion(): string { + return `go ${this.moduleConfig.version ?? ModuleConfig.DEFAULT.version}`; + } + + private writeImports(): string { + if (this.moduleConfig.imports == null) { + return ""; + } + return Object.entries(this.moduleConfig.imports) + .map(([importPath, version]) => this.writeImport({ importPath, version })) + .join("\n"); + } + + private writeImport({ importPath, version }: { importPath: string; version: string }): string { + return `require ${importPath} ${version}`; + } +} diff --git a/generators/go-v2/sdk/src/raw-client/RawClientGenerator.ts b/generators/go-v2/sdk/src/raw-client/RawClientGenerator.ts new file mode 100644 index 00000000000..f92e781c024 --- /dev/null +++ b/generators/go-v2/sdk/src/raw-client/RawClientGenerator.ts @@ -0,0 +1,159 @@ +import { RelativeFilePath, join } from "@fern-api/fs-utils"; +import { go } from "@fern-api/go-ast"; +import { FileGenerator, GoFile } from "@fern-api/go-base"; + +import { HttpService, ServiceId, Subpackage } from "@fern-fern/ir-sdk/api"; + +import { SdkCustomConfigSchema } from "../SdkCustomConfig"; +import { SdkGeneratorContext } from "../SdkGeneratorContext"; + +export declare namespace RawClientGenerator { + interface Args { + context: SdkGeneratorContext; + subpackage: Subpackage; + serviceId: ServiceId; + service: HttpService; + } +} + +export class RawClientGenerator extends FileGenerator { + private subpackage: Subpackage | undefined; + private serviceId: ServiceId; + private service: HttpService; + + constructor({ subpackage, context, serviceId, service }: RawClientGenerator.Args) { + super(context); + this.subpackage = subpackage; + this.serviceId = serviceId; + this.service = service; + } + + public doGenerate(): GoFile { + const struct = go.struct({ + ...this.getClassReference() + }); + + struct.addConstructor(this.getConstructor()); + + struct.addField( + go.field({ + name: "baseURL", + type: go.Type.string() + }), + this.context.caller.getField(), + go.field({ + name: "header", + type: go.Type.reference(this.context.getNetHttpHeaderTypeReference()) + }) + ); + + for (const endpoint of this.service.endpoints) { + const methods = this.context.endpointGenerator.generateRaw({ + serviceId: this.serviceId, + service: this.service, + subpackage: this.subpackage, + endpoint + }); + struct.addMethod(...methods); + } + + return new GoFile({ + node: struct, + rootImportPath: this.context.getRootImportPath(), + packageName: this.getPackageName(), + importPath: this.getImportPath(), + directory: this.getDirectory(), + filename: this.context.getRawClientFilename(), + customConfig: this.context.customConfig + }); + } + + protected getFilepath(): RelativeFilePath { + return join(this.getDirectory(), RelativeFilePath.of(this.context.getRawClientFilename())); + } + + private getConstructor(): go.Struct.Constructor { + return { + parameters: [ + go.parameter({ + name: "opts", + type: this.context.getVariadicRequestOptionType() + }) + ], + body: go.codeblock((writer) => { + writer.write("options := "); + writer.writeNode(this.context.callNewRequestOptions(go.codeblock("opts..."))); + writer.newLine(); + writer.write("return "); + writer.writeNode( + go.TypeInstantiation.structPointer({ + typeReference: this.getClassReference(), + fields: [ + { + name: "baseURL", + value: go.TypeInstantiation.reference( + go.selector({ + on: go.codeblock("options"), + selector: go.codeblock("BaseURL") + }) + ) + }, + { + name: "caller", + value: go.TypeInstantiation.reference( + this.context.caller.instantiate({ + client: go.TypeInstantiation.reference( + go.selector({ + on: go.codeblock("options"), + selector: go.codeblock("HTTPClient") + }) + ), + maxAttempts: go.TypeInstantiation.reference( + go.selector({ + on: go.codeblock("options"), + selector: go.codeblock("MaxAttempts") + }) + ) + }) + ) + }, + { + name: "header", + value: go.TypeInstantiation.reference( + go.selector({ + on: go.codeblock("options"), + selector: go.codeblock("ToHeader()") + }) + ) + } + ] + }) + ); + }) + }; + } + + private getClassReference(): go.TypeReference { + return this.subpackage != null + ? this.context.getSubpackageRawClientClassReference(this.subpackage) + : this.context.getRootRawClientClassReference(); + } + + private getPackageName(): string { + return this.subpackage != null + ? this.context.getSubpackageClientPackageName(this.subpackage) + : this.context.getClientPackageName(); + } + + private getDirectory(): RelativeFilePath { + return this.subpackage != null + ? this.context.getSubpackageClientFileLocation(this.subpackage).directory + : this.context.getRootClientDirectory(); + } + + private getImportPath(): string { + return this.subpackage != null + ? this.context.getSubpackageClientClassReference(this.subpackage).importPath + : this.context.getRootClientImportPath(); + } +} diff --git a/generators/go-v2/sdk/src/subpackage-client/SubPackageClientGenerator.ts b/generators/go-v2/sdk/src/subpackage-client/SubPackageClientGenerator.ts new file mode 100644 index 00000000000..c7e6794b80e --- /dev/null +++ b/generators/go-v2/sdk/src/subpackage-client/SubPackageClientGenerator.ts @@ -0,0 +1,82 @@ +import { RelativeFilePath, join } from "@fern-api/fs-utils"; +import { go } from "@fern-api/go-ast"; +import { FileGenerator, GoFile } from "@fern-api/go-base"; + +import { HttpService, ServiceId, Subpackage } from "@fern-fern/ir-sdk/api"; + +import { SdkCustomConfigSchema } from "../SdkCustomConfig"; +import { SdkGeneratorContext } from "../SdkGeneratorContext"; + +export declare namespace SubClientGenerator { + interface Args { + context: SdkGeneratorContext; + subpackage: Subpackage; + serviceId?: ServiceId; + service?: HttpService; + } +} + +export class SubPackageClientGenerator extends FileGenerator { + private classReference: go.TypeReference; + private subpackage: Subpackage; + private serviceId: ServiceId | undefined; + private service: HttpService | undefined; + + constructor({ subpackage, context, serviceId, service }: SubClientGenerator.Args) { + super(context); + this.classReference = this.context.getSubpackageClientClassReference(subpackage); + this.subpackage = subpackage; + this.serviceId = serviceId; + this.service = service; + } + + public doGenerate(): GoFile { + const struct = go.struct({ + ...this.classReference + }); + + const subpackages = this.getSubpackages(); + for (const subpackage of subpackages) { + struct.addField(this.context.getSubpackageClientField(subpackage)); + } + + if (this.service != null && this.serviceId != null) { + for (const endpoint of this.service.endpoints) { + const methods = this.context.endpointGenerator.generate({ + serviceId: this.serviceId, + service: this.service, + subpackage: this.subpackage, + endpoint + }); + for (const method of methods) { + struct.addMethod(method); + } + } + } + + return new GoFile({ + node: struct, + rootImportPath: this.context.getRootImportPath(), + packageName: this.context.getClientPackageName(), + importPath: this.context.getSubpackageClientClassReference(this.subpackage).importPath, + directory: this.context.getSubpackageClientFileLocation(this.subpackage).directory, + filename: this.context.getClientFilename(), + customConfig: this.context.customConfig + }); + } + + private getSubpackages(): Subpackage[] { + return this.subpackage.subpackages + .map((subpackageId) => { + return this.context.getSubpackageOrThrow(subpackageId); + }) + .filter((subpackage) => this.context.shouldGenerateSubpackageClient(subpackage)); + } + + protected getFilepath(): RelativeFilePath { + return join( + this.context.getSubpackageClientFileLocation(this.subpackage).directory, + RelativeFilePath.of(this.context.getClientFilename()) + ); + } +} diff --git a/generators/go/internal/generator/generator.go b/generators/go/internal/generator/generator.go index 32b7a75302b..2c6df4ca079 100644 --- a/generators/go/internal/generator/generator.go +++ b/generators/go/internal/generator/generator.go @@ -736,8 +736,7 @@ func (g *Generator) generate(ir *fernir.IntermediateRepresentation, mode Mode) ( // The go.sum file will be generated after the // go.mod file is written to disk. if g.config.ModuleConfig != nil { - requiresGenerics := g.config.EnableExplicitNull || ir.SdkConfig.HasStreamingEndpoints || generatedPagination - file, generatedGoVersion, err := NewModFile(g.coordinator, g.config.ModuleConfig, requiresGenerics) + file, generatedGoVersion, err := NewModFile(g.coordinator, g.config.ModuleConfig) if err != nil { return nil, err } diff --git a/generators/go/internal/generator/mod_file.go b/generators/go/internal/generator/mod_file.go index 646e2130ba5..af5889bdf72 100644 --- a/generators/go/internal/generator/mod_file.go +++ b/generators/go/internal/generator/mod_file.go @@ -9,13 +9,9 @@ import ( const ( // minimumGoVersion specifies the minimum Go version required to - // use this library. We require at least 1.13, which is when - // modules were officially introduced. - minimumGoVersion = "1.13" - - // minimumGoGenericsVersion specifies the minimum Go version if - // the user requires generics (i.e. *Optional[T] or *Stream[T]). - minimumGoGenericsVersion = "1.18" + // use this library. We require at least 1.18, which is when + // generics were officially introduced. + minimumGoVersion = "1.18" // modFilename is the default name of a Go module file. modFilename = "go.mod" @@ -27,10 +23,10 @@ const ( // // module github.com/fern-api/fern-go // -// go 1.13 +// go 1.18 // // require github.com/google/uuid v1.4.0 -func NewModFile(coordinator *coordinator.Client, c *ModuleConfig, requiresGenerics bool) (*File, string, error) { +func NewModFile(coordinator *coordinator.Client, c *ModuleConfig) (*File, string, error) { if c.Path == "" { return nil, "", fmt.Errorf("module path is required") } @@ -44,11 +40,7 @@ func NewModFile(coordinator *coordinator.Client, c *ModuleConfig, requiresGeneri // Write the go version. version := c.Version if version == "" { - if requiresGenerics { - version = minimumGoGenericsVersion - } else { - version = minimumGoVersion - } + version = minimumGoVersion } fmt.Fprintf(buffer, "go %s\n", version) fmt.Fprintln(buffer) diff --git a/generators/go/internal/generator/sdk.go b/generators/go/internal/generator/sdk.go index 765df85e127..0ed711c3ff5 100644 --- a/generators/go/internal/generator/sdk.go +++ b/generators/go/internal/generator/sdk.go @@ -1294,9 +1294,9 @@ func (f *fileWriter) WriteClient( if endpoint.Method == "http.MethodHead" { // HEAD requests don't have a response body, so we can simply return the raw // response headers. - f.P("response, err := ", receiver, ".caller.CallRaw(") + f.P("response, err := ", receiver, ".caller.Call(") f.P("ctx,") - f.P("&internal.CallRawParams{") + f.P("&internal.CallParams{") f.P("URL: endpointURL, ") f.P("Method:", endpoint.Method, ",") f.P("Headers:", headersParameter, ",") @@ -1315,7 +1315,6 @@ func (f *fileWriter) WriteClient( f.P("if err != nil {") f.P("return ", endpoint.ErrorReturnValues) f.P("}") - f.P("defer response.Body.Close()") f.P("return response.Header, nil") f.P("}") f.P() @@ -1488,7 +1487,7 @@ func (f *fileWriter) WriteClient( f.P("}") f.P() } else { - f.P("if err := ", receiver, ".caller.Call(") + f.P("if _, err := ", receiver, ".caller.Call(") f.P("ctx,") f.P("&internal.CallParams{") f.P("URL: endpointURL, ") diff --git a/generators/go/internal/generator/sdk/core/api_error.go b/generators/go/internal/generator/sdk/core/api_error.go index dc4190ca1cd..6168388541b 100644 --- a/generators/go/internal/generator/sdk/core/api_error.go +++ b/generators/go/internal/generator/sdk/core/api_error.go @@ -1,19 +1,24 @@ package core -import "fmt" +import ( + "fmt" + "net/http" +) // APIError is a lightweight wrapper around the standard error // interface that preserves the status code from the RPC, if any. type APIError struct { err error - StatusCode int `json:"-"` + StatusCode int `json:"-"` + Header http.Header `json:"-"` } // NewAPIError constructs a new API error. -func NewAPIError(statusCode int, err error) *APIError { +func NewAPIError(statusCode int, header http.Header, err error) *APIError { return &APIError{ err: err, + Header: header, StatusCode: statusCode, } } diff --git a/generators/go/internal/generator/sdk/core/http.go b/generators/go/internal/generator/sdk/core/http.go index b553350b84e..92c43569294 100644 --- a/generators/go/internal/generator/sdk/core/http.go +++ b/generators/go/internal/generator/sdk/core/http.go @@ -6,3 +6,10 @@ import "net/http" type HTTPClient interface { Do(*http.Request) (*http.Response, error) } + +// Response is an HTTP response from an HTTP client. +type Response[T any] struct { + StatusCode int + Header http.Header + Body T +} diff --git a/generators/go/internal/generator/sdk/internal/caller.go b/generators/go/internal/generator/sdk/internal/caller.go index c7edd25a8e5..a4b94ffafde 100644 --- a/generators/go/internal/generator/sdk/internal/caller.go +++ b/generators/go/internal/generator/sdk/internal/caller.go @@ -64,67 +64,14 @@ type CallParams struct { ErrorDecoder ErrorDecoder } -// Call issues an API call according to the given call parameters. -func (c *Caller) Call(ctx context.Context, params *CallParams) error { - resp, err := c.CallRaw( - ctx, - &CallRawParams{ - URL: params.URL, - Method: params.Method, - MaxAttempts: params.MaxAttempts, - Headers: params.Headers, - BodyProperties: params.BodyProperties, - QueryParameters: params.QueryParameters, - Client: params.Client, - Request: params.Request, - ErrorDecoder: params.ErrorDecoder, - }, - ) - if err != nil { - return err - } - - // Close the response body after we're done. - defer resp.Body.Close() - - if params.Response != nil { - if writer, ok := params.Response.(io.Writer); ok { - _, err = io.Copy(writer, resp.Body) - } else { - err = json.NewDecoder(resp.Body).Decode(params.Response) - } - if err != nil { - if err == io.EOF { - if params.ResponseIsOptional { - // The response is optional, so we should ignore the - // io.EOF error - return nil - } - return fmt.Errorf("expected a %T response, but the server responded with nothing", params.Response) - } - return err - } - } - - return nil -} - -// CallRawParams represents the parameters used to issue an API call. -type CallRawParams struct { - URL string - Method string - MaxAttempts uint - Headers http.Header - BodyProperties map[string]interface{} - QueryParameters url.Values - Client core.HTTPClient - Request interface{} - ErrorDecoder ErrorDecoder +// CallResponse is a parsed HTTP response from an API call. +type CallResponse struct { + StatusCode int + Header http.Header } -// CallRaw issues an API call according to the given call parameters and returns the raw HTTP response. -// The caller is responsible for closing the response body. -func (c *Caller) CallRaw(ctx context.Context, params *CallRawParams) (*http.Response, error) { +// Call issues an API call according to the given call parameters. +func (c *Caller) Call(ctx context.Context, params *CallParams) (*CallResponse, error) { url := buildURL(params.URL, params.QueryParameters) req, err := newRequest( ctx, @@ -145,6 +92,7 @@ func (c *Caller) CallRaw(ctx context.Context, params *CallRawParams) (*http.Resp client := c.client if params.Client != nil { + // Use the HTTP client scoped to the request. client = params.Client } @@ -163,19 +111,46 @@ func (c *Caller) CallRaw(ctx context.Context, params *CallRawParams) (*http.Resp return nil, err } + // Close the response body after we're done. + defer resp.Body.Close() + // Check if the call was cancelled before we return the error // associated with the call and/or unmarshal the response data. if err := ctx.Err(); err != nil { - defer resp.Body.Close() return nil, err } if resp.StatusCode < 200 || resp.StatusCode >= 300 { - defer resp.Body.Close() return nil, decodeError(resp, params.ErrorDecoder) } - return resp, nil + // Mutate the response parameter in-place. + if params.Response != nil { + if writer, ok := params.Response.(io.Writer); ok { + _, err = io.Copy(writer, resp.Body) + } else { + err = json.NewDecoder(resp.Body).Decode(params.Response) + } + if err != nil { + if err == io.EOF { + if params.ResponseIsOptional { + // The response is optional, so we should ignore the + // io.EOF error + return &CallResponse{ + StatusCode: resp.StatusCode, + Header: resp.Header, + }, nil + } + return nil, fmt.Errorf("expected a %T response, but the server responded with nothing", params.Response) + } + return nil, err + } + } + + return &CallResponse{ + StatusCode: resp.StatusCode, + Header: resp.Header, + }, nil } // buildURL constructs the final URL by appending the given query parameters (if any). @@ -250,7 +225,7 @@ func decodeError(response *http.Response, errorDecoder ErrorDecoder) error { // This endpoint has custom errors, so we'll // attempt to unmarshal the error into a structured // type based on the status code. - return errorDecoder(response.StatusCode, response.Body) + return errorDecoder(response.StatusCode, response.Header, response.Body) } // This endpoint doesn't have any custom error // types, so we just read the body as-is, and @@ -263,9 +238,9 @@ func decodeError(response *http.Response, errorDecoder ErrorDecoder) error { // The error didn't have a response body, // so all we can do is return an error // with the status code. - return core.NewAPIError(response.StatusCode, nil) + return core.NewAPIError(response.StatusCode, response.Header, nil) } - return core.NewAPIError(response.StatusCode, errors.New(string(bytes))) + return core.NewAPIError(response.StatusCode, response.Header, errors.New(string(bytes))) } // isNil is used to determine if the request value is equal to nil (i.e. an interface diff --git a/generators/go/internal/generator/sdk/internal/caller_test.go b/generators/go/internal/generator/sdk/internal/caller_test.go index 2ab3a2141ce..c5e364f7a2b 100644 --- a/generators/go/internal/generator/sdk/internal/caller_test.go +++ b/generators/go/internal/generator/sdk/internal/caller_test.go @@ -102,6 +102,7 @@ func TestCall(t *testing.T) { wantError: &NotFoundError{ APIError: core.NewAPIError( http.StatusNotFound, + http.Header{}, errors.New(`{"message":"ID \"404\" not found"}`), ), }, @@ -115,6 +116,7 @@ func TestCall(t *testing.T) { giveRequest: nil, wantError: core.NewAPIError( http.StatusBadRequest, + http.Header{}, errors.New("invalid request"), ), }, @@ -140,6 +142,7 @@ func TestCall(t *testing.T) { }, wantError: core.NewAPIError( http.StatusInternalServerError, + http.Header{}, errors.New("failed to process request"), ), }, @@ -212,7 +215,7 @@ func TestCall(t *testing.T) { }, ) var response *Response - err := caller.Call( + _, err := caller.Call( context.Background(), &CallParams{ URL: server.URL + test.givePathSuffix, @@ -236,67 +239,6 @@ func TestCall(t *testing.T) { } } -func TestCallRaw(t *testing.T) { - tests := []*TestCase{ - { - description: "HEAD success", - giveMethod: http.MethodHead, - giveHeader: http.Header{ - "X-API-Status": []string{"success"}, - }, - wantHeaders: http.Header{ - "Content-Length": []string{"250"}, - "Date": []string{"1970-01-01"}, - }, - }, - } - for _, test := range tests { - t.Run(test.description, func(t *testing.T) { - server := httptest.NewServer( - http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, test.giveMethod, r.Method) - for header, value := range test.giveHeader { - assert.Equal(t, value, r.Header.Values(header)) - } - for header, values := range test.wantHeaders { - for _, value := range values { - w.Header().Add(header, value) - } - } - w.WriteHeader(http.StatusOK) - }, - ), - ) - defer server.Close() - - caller := NewCaller( - &CallerParams{ - Client: server.Client(), - }, - ) - response, err := caller.CallRaw( - context.Background(), - &CallRawParams{ - URL: server.URL + test.givePathSuffix, - Method: test.giveMethod, - Headers: test.giveHeader, - BodyProperties: test.giveBodyProperties, - QueryParameters: test.giveQueryParams, - Request: test.giveRequest, - ErrorDecoder: test.giveErrorDecoder, - }, - ) - if test.wantError != nil { - assert.EqualError(t, err, test.wantError.Error()) - return - } - require.NoError(t, err) - assert.Equal(t, test.wantHeaders, response.Header) - }) - } -} - func TestMergeHeaders(t *testing.T) { t.Run("both empty", func(t *testing.T) { merged := MergeHeaders(make(http.Header), make(http.Header)) @@ -432,13 +374,13 @@ func newTestServer(t *testing.T, tc *TestCase) *httptest.Server { } // newTestErrorDecoder returns an error decoder suitable for tests. -func newTestErrorDecoder(t *testing.T) func(int, io.Reader) error { - return func(statusCode int, body io.Reader) error { +func newTestErrorDecoder(t *testing.T) func(int, http.Header, io.Reader) error { + return func(statusCode int, header http.Header, body io.Reader) error { raw, err := io.ReadAll(body) require.NoError(t, err) var ( - apiError = core.NewAPIError(statusCode, errors.New(string(raw))) + apiError = core.NewAPIError(statusCode, header, errors.New(string(raw))) decoder = json.NewDecoder(bytes.NewReader(raw)) ) if statusCode == http.StatusNotFound { diff --git a/generators/go/internal/generator/sdk/internal/error_decoder.go b/generators/go/internal/generator/sdk/internal/error_decoder.go index a67415b080f..6a04e1b30eb 100644 --- a/generators/go/internal/generator/sdk/internal/error_decoder.go +++ b/generators/go/internal/generator/sdk/internal/error_decoder.go @@ -6,26 +6,28 @@ import ( "errors" "fmt" "io" + "net/http" "github.com/fern-api/fern-go/internal/generator/sdk/core" ) // ErrorDecoder decodes *http.Response errors and returns a // typed API error (e.g. *core.APIError). -type ErrorDecoder func(statusCode int, body io.Reader) error +type ErrorDecoder func(statusCode int, header http.Header, body io.Reader) error // ErrorCodes maps HTTP status codes to error constructors. type ErrorCodes map[int]func(*core.APIError) error // NewErrorDecoder returns a new ErrorDecoder backed by the given error codes. func NewErrorDecoder(errorCodes ErrorCodes) ErrorDecoder { - return func(statusCode int, body io.Reader) error { + return func(statusCode int, header http.Header, body io.Reader) error { raw, err := io.ReadAll(body) if err != nil { return fmt.Errorf("failed to read error from response body: %w", err) } apiError := core.NewAPIError( statusCode, + header, errors.New(string(raw)), ) newErrorFunc, ok := errorCodes[statusCode] diff --git a/generators/go/internal/generator/sdk/internal/error_decoder_test.go b/generators/go/internal/generator/sdk/internal/error_decoder_test.go index 850600447b4..c9c0415977d 100644 --- a/generators/go/internal/generator/sdk/internal/error_decoder_test.go +++ b/generators/go/internal/generator/sdk/internal/error_decoder_test.go @@ -21,35 +21,39 @@ func TestErrorDecoder(t *testing.T) { tests := []struct { description string giveStatusCode int + giveHeader http.Header giveBody string wantError error }{ { description: "unrecognized status code", giveStatusCode: http.StatusInternalServerError, + giveHeader: http.Header{}, giveBody: "Internal Server Error", - wantError: core.NewAPIError(http.StatusInternalServerError, errors.New("Internal Server Error")), + wantError: core.NewAPIError(http.StatusInternalServerError, http.Header{}, errors.New("Internal Server Error")), }, { description: "not found with valid JSON", giveStatusCode: http.StatusNotFound, + giveHeader: http.Header{}, giveBody: `{"message": "Resource not found"}`, wantError: &NotFoundError{ - APIError: core.NewAPIError(http.StatusNotFound, errors.New(`{"message": "Resource not found"}`)), + APIError: core.NewAPIError(http.StatusNotFound, http.Header{}, errors.New(`{"message": "Resource not found"}`)), Message: "Resource not found", }, }, { description: "not found with invalid JSON", giveStatusCode: http.StatusNotFound, + giveHeader: http.Header{}, giveBody: `Resource not found`, - wantError: core.NewAPIError(http.StatusNotFound, errors.New("Resource not found")), + wantError: core.NewAPIError(http.StatusNotFound, http.Header{}, errors.New("Resource not found")), }, } for _, tt := range tests { t.Run(tt.description, func(t *testing.T) { - assert.Equal(t, tt.wantError, decoder(tt.giveStatusCode, bytes.NewReader([]byte(tt.giveBody)))) + assert.Equal(t, tt.wantError, decoder(tt.giveStatusCode, tt.giveHeader, bytes.NewReader([]byte(tt.giveBody)))) }) } } diff --git a/generators/go/internal/generator/sdk/internal/retrier_test.go b/generators/go/internal/generator/sdk/internal/retrier_test.go index 79397451450..13e30076c44 100644 --- a/generators/go/internal/generator/sdk/internal/retrier_test.go +++ b/generators/go/internal/generator/sdk/internal/retrier_test.go @@ -107,7 +107,7 @@ func TestRetrier(t *testing.T) { ) var response *Response - err := caller.Call( + _, err := caller.Call( context.Background(), &CallParams{ URL: server.URL, diff --git a/seed/go-sdk/idempotency-headers/core/api_error.go b/seed/go-sdk/idempotency-headers/core/api_error.go index dc4190ca1cd..6168388541b 100644 --- a/seed/go-sdk/idempotency-headers/core/api_error.go +++ b/seed/go-sdk/idempotency-headers/core/api_error.go @@ -1,19 +1,24 @@ package core -import "fmt" +import ( + "fmt" + "net/http" +) // APIError is a lightweight wrapper around the standard error // interface that preserves the status code from the RPC, if any. type APIError struct { err error - StatusCode int `json:"-"` + StatusCode int `json:"-"` + Header http.Header `json:"-"` } // NewAPIError constructs a new API error. -func NewAPIError(statusCode int, err error) *APIError { +func NewAPIError(statusCode int, header http.Header, err error) *APIError { return &APIError{ err: err, + Header: header, StatusCode: statusCode, } } diff --git a/seed/go-sdk/idempotency-headers/core/http.go b/seed/go-sdk/idempotency-headers/core/http.go index b553350b84e..92c43569294 100644 --- a/seed/go-sdk/idempotency-headers/core/http.go +++ b/seed/go-sdk/idempotency-headers/core/http.go @@ -6,3 +6,10 @@ import "net/http" type HTTPClient interface { Do(*http.Request) (*http.Response, error) } + +// Response is an HTTP response from an HTTP client. +type Response[T any] struct { + StatusCode int + Header http.Header + Body T +} diff --git a/seed/go-sdk/idempotency-headers/go.mod b/seed/go-sdk/idempotency-headers/go.mod index 418883f3713..8e93ee2b0ca 100644 --- a/seed/go-sdk/idempotency-headers/go.mod +++ b/seed/go-sdk/idempotency-headers/go.mod @@ -1,9 +1,14 @@ module github.com/idempotency-headers/fern -go 1.13 +go 1.18 require ( github.com/google/uuid v1.4.0 github.com/stretchr/testify v1.7.0 +) + +require ( + github.com/davecgh/go-spew v1.1.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/seed/go-sdk/idempotency-headers/internal/caller.go b/seed/go-sdk/idempotency-headers/internal/caller.go index d5363fcd3ef..6cc9c680f1b 100644 --- a/seed/go-sdk/idempotency-headers/internal/caller.go +++ b/seed/go-sdk/idempotency-headers/internal/caller.go @@ -64,67 +64,14 @@ type CallParams struct { ErrorDecoder ErrorDecoder } -// Call issues an API call according to the given call parameters. -func (c *Caller) Call(ctx context.Context, params *CallParams) error { - resp, err := c.CallRaw( - ctx, - &CallRawParams{ - URL: params.URL, - Method: params.Method, - MaxAttempts: params.MaxAttempts, - Headers: params.Headers, - BodyProperties: params.BodyProperties, - QueryParameters: params.QueryParameters, - Client: params.Client, - Request: params.Request, - ErrorDecoder: params.ErrorDecoder, - }, - ) - if err != nil { - return err - } - - // Close the response body after we're done. - defer resp.Body.Close() - - if params.Response != nil { - if writer, ok := params.Response.(io.Writer); ok { - _, err = io.Copy(writer, resp.Body) - } else { - err = json.NewDecoder(resp.Body).Decode(params.Response) - } - if err != nil { - if err == io.EOF { - if params.ResponseIsOptional { - // The response is optional, so we should ignore the - // io.EOF error - return nil - } - return fmt.Errorf("expected a %T response, but the server responded with nothing", params.Response) - } - return err - } - } - - return nil -} - -// CallRawParams represents the parameters used to issue an API call. -type CallRawParams struct { - URL string - Method string - MaxAttempts uint - Headers http.Header - BodyProperties map[string]interface{} - QueryParameters url.Values - Client core.HTTPClient - Request interface{} - ErrorDecoder ErrorDecoder +// CallResponse is a parsed HTTP response from an API call. +type CallResponse struct { + StatusCode int + Header http.Header } -// CallRaw issues an API call according to the given call parameters and returns the raw HTTP response. -// The caller is responsible for closing the response body. -func (c *Caller) CallRaw(ctx context.Context, params *CallRawParams) (*http.Response, error) { +// Call issues an API call according to the given call parameters. +func (c *Caller) Call(ctx context.Context, params *CallParams) (*CallResponse, error) { url := buildURL(params.URL, params.QueryParameters) req, err := newRequest( ctx, @@ -145,6 +92,7 @@ func (c *Caller) CallRaw(ctx context.Context, params *CallRawParams) (*http.Resp client := c.client if params.Client != nil { + // Use the HTTP client scoped to the request. client = params.Client } @@ -163,19 +111,46 @@ func (c *Caller) CallRaw(ctx context.Context, params *CallRawParams) (*http.Resp return nil, err } + // Close the response body after we're done. + defer resp.Body.Close() + // Check if the call was cancelled before we return the error // associated with the call and/or unmarshal the response data. if err := ctx.Err(); err != nil { - defer resp.Body.Close() return nil, err } if resp.StatusCode < 200 || resp.StatusCode >= 300 { - defer resp.Body.Close() return nil, decodeError(resp, params.ErrorDecoder) } - return resp, nil + // Mutate the response parameter in-place. + if params.Response != nil { + if writer, ok := params.Response.(io.Writer); ok { + _, err = io.Copy(writer, resp.Body) + } else { + err = json.NewDecoder(resp.Body).Decode(params.Response) + } + if err != nil { + if err == io.EOF { + if params.ResponseIsOptional { + // The response is optional, so we should ignore the + // io.EOF error + return &CallResponse{ + StatusCode: resp.StatusCode, + Header: resp.Header, + }, nil + } + return nil, fmt.Errorf("expected a %T response, but the server responded with nothing", params.Response) + } + return nil, err + } + } + + return &CallResponse{ + StatusCode: resp.StatusCode, + Header: resp.Header, + }, nil } // buildURL constructs the final URL by appending the given query parameters (if any). @@ -250,7 +225,7 @@ func decodeError(response *http.Response, errorDecoder ErrorDecoder) error { // This endpoint has custom errors, so we'll // attempt to unmarshal the error into a structured // type based on the status code. - return errorDecoder(response.StatusCode, response.Body) + return errorDecoder(response.StatusCode, response.Header, response.Body) } // This endpoint doesn't have any custom error // types, so we just read the body as-is, and @@ -263,9 +238,9 @@ func decodeError(response *http.Response, errorDecoder ErrorDecoder) error { // The error didn't have a response body, // so all we can do is return an error // with the status code. - return core.NewAPIError(response.StatusCode, nil) + return core.NewAPIError(response.StatusCode, response.Header, nil) } - return core.NewAPIError(response.StatusCode, errors.New(string(bytes))) + return core.NewAPIError(response.StatusCode, response.Header, errors.New(string(bytes))) } // isNil is used to determine if the request value is equal to nil (i.e. an interface diff --git a/seed/go-sdk/idempotency-headers/internal/caller_test.go b/seed/go-sdk/idempotency-headers/internal/caller_test.go index b60f8b7744d..f5f86b7d769 100644 --- a/seed/go-sdk/idempotency-headers/internal/caller_test.go +++ b/seed/go-sdk/idempotency-headers/internal/caller_test.go @@ -102,6 +102,7 @@ func TestCall(t *testing.T) { wantError: &NotFoundError{ APIError: core.NewAPIError( http.StatusNotFound, + http.Header{}, errors.New(`{"message":"ID \"404\" not found"}`), ), }, @@ -115,6 +116,7 @@ func TestCall(t *testing.T) { giveRequest: nil, wantError: core.NewAPIError( http.StatusBadRequest, + http.Header{}, errors.New("invalid request"), ), }, @@ -140,6 +142,7 @@ func TestCall(t *testing.T) { }, wantError: core.NewAPIError( http.StatusInternalServerError, + http.Header{}, errors.New("failed to process request"), ), }, @@ -212,7 +215,7 @@ func TestCall(t *testing.T) { }, ) var response *Response - err := caller.Call( + _, err := caller.Call( context.Background(), &CallParams{ URL: server.URL + test.givePathSuffix, @@ -236,67 +239,6 @@ func TestCall(t *testing.T) { } } -func TestCallRaw(t *testing.T) { - tests := []*TestCase{ - { - description: "HEAD success", - giveMethod: http.MethodHead, - giveHeader: http.Header{ - "X-API-Status": []string{"success"}, - }, - wantHeaders: http.Header{ - "Content-Length": []string{"250"}, - "Date": []string{"1970-01-01"}, - }, - }, - } - for _, test := range tests { - t.Run(test.description, func(t *testing.T) { - server := httptest.NewServer( - http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, test.giveMethod, r.Method) - for header, value := range test.giveHeader { - assert.Equal(t, value, r.Header.Values(header)) - } - for header, values := range test.wantHeaders { - for _, value := range values { - w.Header().Add(header, value) - } - } - w.WriteHeader(http.StatusOK) - }, - ), - ) - defer server.Close() - - caller := NewCaller( - &CallerParams{ - Client: server.Client(), - }, - ) - response, err := caller.CallRaw( - context.Background(), - &CallRawParams{ - URL: server.URL + test.givePathSuffix, - Method: test.giveMethod, - Headers: test.giveHeader, - BodyProperties: test.giveBodyProperties, - QueryParameters: test.giveQueryParams, - Request: test.giveRequest, - ErrorDecoder: test.giveErrorDecoder, - }, - ) - if test.wantError != nil { - assert.EqualError(t, err, test.wantError.Error()) - return - } - require.NoError(t, err) - assert.Equal(t, test.wantHeaders, response.Header) - }) - } -} - func TestMergeHeaders(t *testing.T) { t.Run("both empty", func(t *testing.T) { merged := MergeHeaders(make(http.Header), make(http.Header)) @@ -432,13 +374,13 @@ func newTestServer(t *testing.T, tc *TestCase) *httptest.Server { } // newTestErrorDecoder returns an error decoder suitable for tests. -func newTestErrorDecoder(t *testing.T) func(int, io.Reader) error { - return func(statusCode int, body io.Reader) error { +func newTestErrorDecoder(t *testing.T) func(int, http.Header, io.Reader) error { + return func(statusCode int, header http.Header, body io.Reader) error { raw, err := io.ReadAll(body) require.NoError(t, err) var ( - apiError = core.NewAPIError(statusCode, errors.New(string(raw))) + apiError = core.NewAPIError(statusCode, header, errors.New(string(raw))) decoder = json.NewDecoder(bytes.NewReader(raw)) ) if statusCode == http.StatusNotFound { diff --git a/seed/go-sdk/idempotency-headers/internal/error_decoder.go b/seed/go-sdk/idempotency-headers/internal/error_decoder.go index af59b517923..577eda5da4e 100644 --- a/seed/go-sdk/idempotency-headers/internal/error_decoder.go +++ b/seed/go-sdk/idempotency-headers/internal/error_decoder.go @@ -6,26 +6,28 @@ import ( "errors" "fmt" "io" + "net/http" "github.com/idempotency-headers/fern/core" ) // ErrorDecoder decodes *http.Response errors and returns a // typed API error (e.g. *core.APIError). -type ErrorDecoder func(statusCode int, body io.Reader) error +type ErrorDecoder func(statusCode int, header http.Header, body io.Reader) error // ErrorCodes maps HTTP status codes to error constructors. type ErrorCodes map[int]func(*core.APIError) error // NewErrorDecoder returns a new ErrorDecoder backed by the given error codes. func NewErrorDecoder(errorCodes ErrorCodes) ErrorDecoder { - return func(statusCode int, body io.Reader) error { + return func(statusCode int, header http.Header, body io.Reader) error { raw, err := io.ReadAll(body) if err != nil { return fmt.Errorf("failed to read error from response body: %w", err) } apiError := core.NewAPIError( statusCode, + header, errors.New(string(raw)), ) newErrorFunc, ok := errorCodes[statusCode] diff --git a/seed/go-sdk/idempotency-headers/internal/error_decoder_test.go b/seed/go-sdk/idempotency-headers/internal/error_decoder_test.go index a37c3765d1b..20add929f7d 100644 --- a/seed/go-sdk/idempotency-headers/internal/error_decoder_test.go +++ b/seed/go-sdk/idempotency-headers/internal/error_decoder_test.go @@ -21,35 +21,39 @@ func TestErrorDecoder(t *testing.T) { tests := []struct { description string giveStatusCode int + giveHeader http.Header giveBody string wantError error }{ { description: "unrecognized status code", giveStatusCode: http.StatusInternalServerError, + giveHeader: http.Header{}, giveBody: "Internal Server Error", - wantError: core.NewAPIError(http.StatusInternalServerError, errors.New("Internal Server Error")), + wantError: core.NewAPIError(http.StatusInternalServerError, http.Header{}, errors.New("Internal Server Error")), }, { description: "not found with valid JSON", giveStatusCode: http.StatusNotFound, + giveHeader: http.Header{}, giveBody: `{"message": "Resource not found"}`, wantError: &NotFoundError{ - APIError: core.NewAPIError(http.StatusNotFound, errors.New(`{"message": "Resource not found"}`)), + APIError: core.NewAPIError(http.StatusNotFound, http.Header{}, errors.New(`{"message": "Resource not found"}`)), Message: "Resource not found", }, }, { description: "not found with invalid JSON", giveStatusCode: http.StatusNotFound, + giveHeader: http.Header{}, giveBody: `Resource not found`, - wantError: core.NewAPIError(http.StatusNotFound, errors.New("Resource not found")), + wantError: core.NewAPIError(http.StatusNotFound, http.Header{}, errors.New("Resource not found")), }, } for _, tt := range tests { t.Run(tt.description, func(t *testing.T) { - assert.Equal(t, tt.wantError, decoder(tt.giveStatusCode, bytes.NewReader([]byte(tt.giveBody)))) + assert.Equal(t, tt.wantError, decoder(tt.giveStatusCode, tt.giveHeader, bytes.NewReader([]byte(tt.giveBody)))) }) } } diff --git a/seed/go-sdk/idempotency-headers/internal/retrier_test.go b/seed/go-sdk/idempotency-headers/internal/retrier_test.go index 91541833cf1..c4a21b0a4d0 100644 --- a/seed/go-sdk/idempotency-headers/internal/retrier_test.go +++ b/seed/go-sdk/idempotency-headers/internal/retrier_test.go @@ -107,7 +107,7 @@ func TestRetrier(t *testing.T) { ) var response *Response - err := caller.Call( + _, err := caller.Call( context.Background(), &CallParams{ URL: server.URL, diff --git a/seed/go-sdk/idempotency-headers/payment/client.go b/seed/go-sdk/idempotency-headers/payment/client.go index f744e8b2db9..9096bb6297d 100644 --- a/seed/go-sdk/idempotency-headers/payment/client.go +++ b/seed/go-sdk/idempotency-headers/payment/client.go @@ -50,7 +50,7 @@ func (c *Client) Create( ) var response uuid.UUID - if err := c.caller.Call( + if _, err := c.caller.Call( ctx, &internal.CallParams{ URL: endpointURL, @@ -89,7 +89,7 @@ func (c *Client) Delete( options.ToHeader(), ) - if err := c.caller.Call( + if _, err := c.caller.Call( ctx, &internal.CallParams{ URL: endpointURL, diff --git a/seed/go-sdk/idempotency-headers/payment/raw_client.go b/seed/go-sdk/idempotency-headers/payment/raw_client.go new file mode 100644 index 00000000000..8c4f6cd6154 --- /dev/null +++ b/seed/go-sdk/idempotency-headers/payment/raw_client.go @@ -0,0 +1,113 @@ +package payment + +import ( + context "context" + uuid "github.com/google/uuid" + fern "github.com/idempotency-headers/fern" + core "github.com/idempotency-headers/fern/core" + internal "github.com/idempotency-headers/fern/internal" + option "github.com/idempotency-headers/fern/option" + http "net/http" +) + +type RawClient struct { + baseURL string + caller *internal.Caller + header http.Header +} + +func NewRawClient(opts ...option.RequestOption) *RawClient { + options := core.NewRequestOptions(opts...) + return &RawClient{ + baseURL: options.BaseURL, + caller: internal.NewCaller( + &internal.CallerParams{ + Client: options.HTTPClient, + MaxAttempts: options.MaxAttempts, + }, + ), + header: options.ToHeader(), + } +} + +func (r RawClient) Create( + ctx context.Context, + request *fern.CreatePaymentRequest, + opts ...option.IdempotentRequestOption, +) (*core.Response[uuid.UUID], error) { + options := core.NewIdempotentRequestOptions(opts...) + baseURL := internal.ResolveBaseURL( + options.BaseURL, + r.baseURL, + "", + ) + endpointURL := baseURL + "/payment" + headers := internal.MergeHeaders( + r.header.Clone(), + options.ToHeader(), + ) + var response uuid.UUID + raw, err := r.caller.Call( + ctx, + &internal.CallParams{ + URL: endpointURL, + Method: http.MethodPost, + Headers: headers, + MaxAttempts: options.MaxAttempts, + BodyProperties: options.BodyProperties, + QueryParameters: options.QueryParameters, + Client: options.HTTPClient, + Request: request, + Response: &response, + }, + ) + if err != nil { + return nil, err + } + return &core.Response[uuid.UUID]{ + StatusCode: raw.StatusCode, + Header: raw.Header, + Body: response, + }, nil +} + +func (r RawClient) Delete( + ctx context.Context, + paymentId string, + opts ...option.RequestOption, +) (*core.Response[any], error) { + options := core.NewRequestOptions(opts...) + baseURL := internal.ResolveBaseURL( + options.BaseURL, + r.baseURL, + "", + ) + endpointURL := internal.EncodeURL( + baseURL+"/payment/%v", + paymentId, + ) + headers := internal.MergeHeaders( + r.header.Clone(), + options.ToHeader(), + ) + raw, err := r.caller.Call( + ctx, + &internal.CallParams{ + URL: endpointURL, + Method: http.MethodDelete, + Headers: headers, + MaxAttempts: options.MaxAttempts, + BodyProperties: options.BodyProperties, + QueryParameters: options.QueryParameters, + Client: options.HTTPClient, + }, + ) + if err != nil { + return nil, err + } + return &core.Response[any]{ + StatusCode: raw.StatusCode, + Header: raw.Header, + Body: nil, + }, nil +} diff --git a/seed/go-sdk/imdb/no-custom-config/core/api_error.go b/seed/go-sdk/imdb/no-custom-config/core/api_error.go index dc4190ca1cd..6168388541b 100644 --- a/seed/go-sdk/imdb/no-custom-config/core/api_error.go +++ b/seed/go-sdk/imdb/no-custom-config/core/api_error.go @@ -1,19 +1,24 @@ package core -import "fmt" +import ( + "fmt" + "net/http" +) // APIError is a lightweight wrapper around the standard error // interface that preserves the status code from the RPC, if any. type APIError struct { err error - StatusCode int `json:"-"` + StatusCode int `json:"-"` + Header http.Header `json:"-"` } // NewAPIError constructs a new API error. -func NewAPIError(statusCode int, err error) *APIError { +func NewAPIError(statusCode int, header http.Header, err error) *APIError { return &APIError{ err: err, + Header: header, StatusCode: statusCode, } } diff --git a/seed/go-sdk/imdb/no-custom-config/core/http.go b/seed/go-sdk/imdb/no-custom-config/core/http.go index b553350b84e..92c43569294 100644 --- a/seed/go-sdk/imdb/no-custom-config/core/http.go +++ b/seed/go-sdk/imdb/no-custom-config/core/http.go @@ -6,3 +6,10 @@ import "net/http" type HTTPClient interface { Do(*http.Request) (*http.Response, error) } + +// Response is an HTTP response from an HTTP client. +type Response[T any] struct { + StatusCode int + Header http.Header + Body T +} diff --git a/seed/go-sdk/imdb/no-custom-config/go.mod b/seed/go-sdk/imdb/no-custom-config/go.mod index 4ae3d369ddb..2cc724b507d 100644 --- a/seed/go-sdk/imdb/no-custom-config/go.mod +++ b/seed/go-sdk/imdb/no-custom-config/go.mod @@ -1,6 +1,6 @@ module github.com/imdb/fern -go 1.13 +go 1.18 require ( github.com/google/uuid v1.4.0 diff --git a/seed/go-sdk/imdb/no-custom-config/imdb/client.go b/seed/go-sdk/imdb/no-custom-config/imdb/client.go index c2ce646b66d..b6fbd3f4b1e 100644 --- a/seed/go-sdk/imdb/no-custom-config/imdb/client.go +++ b/seed/go-sdk/imdb/no-custom-config/imdb/client.go @@ -12,6 +12,9 @@ import ( ) type Client struct { + // WithRawResponse can be used to receive raw HTTP response data, such as headers. + WithRawResponse *RawClient + baseURL string caller *internal.Caller header http.Header @@ -20,6 +23,7 @@ type Client struct { func NewClient(opts ...option.RequestOption) *Client { options := core.NewRequestOptions(opts...) return &Client{ + WithRawResponse: NewRawClient(opts...), baseURL: options.BaseURL, caller: internal.NewCaller( &internal.CallerParams{ @@ -37,36 +41,11 @@ func (c *Client) CreateMovie( request *fern.CreateMovieRequest, opts ...option.RequestOption, ) (fern.MovieId, error) { - options := core.NewRequestOptions(opts...) - baseURL := internal.ResolveBaseURL( - options.BaseURL, - c.baseURL, - "", - ) - endpointURL := baseURL + "/movies/create-movie" - headers := internal.MergeHeaders( - c.header.Clone(), - options.ToHeader(), - ) - - var response fern.MovieId - if err := c.caller.Call( - ctx, - &internal.CallParams{ - URL: endpointURL, - Method: http.MethodPost, - Headers: headers, - MaxAttempts: options.MaxAttempts, - BodyProperties: options.BodyProperties, - QueryParameters: options.QueryParameters, - Client: options.HTTPClient, - Request: request, - Response: &response, - }, - ); err != nil { + response, err := c.WithRawResponse.CreateMovie(ctx, request, opts...) + if err != nil { return "", err } - return response, nil + return response.Body, nil } func (c *Client) GetMovie( @@ -74,44 +53,9 @@ func (c *Client) GetMovie( movieId fern.MovieId, opts ...option.RequestOption, ) (*fern.Movie, error) { - options := core.NewRequestOptions(opts...) - baseURL := internal.ResolveBaseURL( - options.BaseURL, - c.baseURL, - "", - ) - endpointURL := internal.EncodeURL( - baseURL+"/movies/%v", - movieId, - ) - headers := internal.MergeHeaders( - c.header.Clone(), - options.ToHeader(), - ) - errorCodes := internal.ErrorCodes{ - 404: func(apiError *core.APIError) error { - return &fern.MovieDoesNotExistError{ - APIError: apiError, - } - }, - } - - var response *fern.Movie - if err := c.caller.Call( - ctx, - &internal.CallParams{ - URL: endpointURL, - Method: http.MethodGet, - Headers: headers, - MaxAttempts: options.MaxAttempts, - BodyProperties: options.BodyProperties, - QueryParameters: options.QueryParameters, - Client: options.HTTPClient, - Response: &response, - ErrorDecoder: internal.NewErrorDecoder(errorCodes), - }, - ); err != nil { + response, err := c.WithRawResponse.GetMovie(ctx, movieId, opts...) + if err != nil { return nil, err } - return response, nil + return response.Body, nil } diff --git a/seed/go-sdk/imdb/no-custom-config/imdb/raw_client.go b/seed/go-sdk/imdb/no-custom-config/imdb/raw_client.go new file mode 100644 index 00000000000..e963bb55bc3 --- /dev/null +++ b/seed/go-sdk/imdb/no-custom-config/imdb/raw_client.go @@ -0,0 +1,127 @@ +// Code generated by Fern. DO NOT EDIT. + +package imdb + +import ( + context "context" + fern "github.com/imdb/fern" + core "github.com/imdb/fern/core" + internal "github.com/imdb/fern/internal" + option "github.com/imdb/fern/option" + http "net/http" +) + +type RawClient struct { + baseURL string + caller *internal.Caller + header http.Header +} + +func NewRawClient(opts ...option.RequestOption) *RawClient { + options := core.NewRequestOptions(opts...) + return &RawClient{ + baseURL: options.BaseURL, + caller: internal.NewCaller( + &internal.CallerParams{ + Client: options.HTTPClient, + MaxAttempts: options.MaxAttempts, + }, + ), + header: options.ToHeader(), + } +} + +// Add a movie to the database using the movies/* /... path. +func (r *RawClient) CreateMovie( + ctx context.Context, + request *fern.CreateMovieRequest, + opts ...option.RequestOption, +) (*core.Response[fern.MovieId], error) { + options := core.NewRequestOptions(opts...) + baseURL := internal.ResolveBaseURL( + options.BaseURL, + r.baseURL, + "", + ) + endpointURL := baseURL + "/movies/create-movie" + headers := internal.MergeHeaders( + r.header.Clone(), + options.ToHeader(), + ) + + var response fern.MovieId + raw, err := r.caller.Call( + ctx, + &internal.CallParams{ + URL: endpointURL, + Method: http.MethodPost, + Headers: headers, + MaxAttempts: options.MaxAttempts, + BodyProperties: options.BodyProperties, + QueryParameters: options.QueryParameters, + Client: options.HTTPClient, + Request: request, + Response: &response, + }, + ) + if err != nil { + return nil, err + } + return &core.Response[fern.MovieId]{ + StatusCode: raw.StatusCode, + Header: raw.Header, + Body: response, + }, nil +} + +func (r *RawClient) GetMovie( + ctx context.Context, + movieId fern.MovieId, + opts ...option.RequestOption, +) (*core.Response[*fern.Movie], error) { + options := core.NewRequestOptions(opts...) + baseURL := internal.ResolveBaseURL( + options.BaseURL, + r.baseURL, + "", + ) + endpointURL := internal.EncodeURL( + baseURL+"/movies/%v", + movieId, + ) + headers := internal.MergeHeaders( + r.header.Clone(), + options.ToHeader(), + ) + errorCodes := internal.ErrorCodes{ + 404: func(apiError *core.APIError) error { + return &fern.MovieDoesNotExistError{ + APIError: apiError, + } + }, + } + + var response *fern.Movie + raw, err := r.caller.Call( + ctx, + &internal.CallParams{ + URL: endpointURL, + Method: http.MethodGet, + Headers: headers, + MaxAttempts: options.MaxAttempts, + BodyProperties: options.BodyProperties, + QueryParameters: options.QueryParameters, + Client: options.HTTPClient, + Response: &response, + ErrorDecoder: internal.NewErrorDecoder(errorCodes), + }, + ) + if err != nil { + return nil, err + } + return &core.Response[*fern.Movie]{ + StatusCode: raw.StatusCode, + Header: raw.Header, + Body: response, + }, nil +} diff --git a/seed/go-sdk/imdb/no-custom-config/internal/caller.go b/seed/go-sdk/imdb/no-custom-config/internal/caller.go index 25c68c7568e..49a475bf422 100644 --- a/seed/go-sdk/imdb/no-custom-config/internal/caller.go +++ b/seed/go-sdk/imdb/no-custom-config/internal/caller.go @@ -64,67 +64,14 @@ type CallParams struct { ErrorDecoder ErrorDecoder } -// Call issues an API call according to the given call parameters. -func (c *Caller) Call(ctx context.Context, params *CallParams) error { - resp, err := c.CallRaw( - ctx, - &CallRawParams{ - URL: params.URL, - Method: params.Method, - MaxAttempts: params.MaxAttempts, - Headers: params.Headers, - BodyProperties: params.BodyProperties, - QueryParameters: params.QueryParameters, - Client: params.Client, - Request: params.Request, - ErrorDecoder: params.ErrorDecoder, - }, - ) - if err != nil { - return err - } - - // Close the response body after we're done. - defer resp.Body.Close() - - if params.Response != nil { - if writer, ok := params.Response.(io.Writer); ok { - _, err = io.Copy(writer, resp.Body) - } else { - err = json.NewDecoder(resp.Body).Decode(params.Response) - } - if err != nil { - if err == io.EOF { - if params.ResponseIsOptional { - // The response is optional, so we should ignore the - // io.EOF error - return nil - } - return fmt.Errorf("expected a %T response, but the server responded with nothing", params.Response) - } - return err - } - } - - return nil -} - -// CallRawParams represents the parameters used to issue an API call. -type CallRawParams struct { - URL string - Method string - MaxAttempts uint - Headers http.Header - BodyProperties map[string]interface{} - QueryParameters url.Values - Client core.HTTPClient - Request interface{} - ErrorDecoder ErrorDecoder +// CallResponse is a parsed HTTP response from an API call. +type CallResponse struct { + StatusCode int + Header http.Header } -// CallRaw issues an API call according to the given call parameters and returns the raw HTTP response. -// The caller is responsible for closing the response body. -func (c *Caller) CallRaw(ctx context.Context, params *CallRawParams) (*http.Response, error) { +// Call issues an API call according to the given call parameters. +func (c *Caller) Call(ctx context.Context, params *CallParams) (*CallResponse, error) { url := buildURL(params.URL, params.QueryParameters) req, err := newRequest( ctx, @@ -145,6 +92,7 @@ func (c *Caller) CallRaw(ctx context.Context, params *CallRawParams) (*http.Resp client := c.client if params.Client != nil { + // Use the HTTP client scoped to the request. client = params.Client } @@ -163,19 +111,46 @@ func (c *Caller) CallRaw(ctx context.Context, params *CallRawParams) (*http.Resp return nil, err } + // Close the response body after we're done. + defer resp.Body.Close() + // Check if the call was cancelled before we return the error // associated with the call and/or unmarshal the response data. if err := ctx.Err(); err != nil { - defer resp.Body.Close() return nil, err } if resp.StatusCode < 200 || resp.StatusCode >= 300 { - defer resp.Body.Close() return nil, decodeError(resp, params.ErrorDecoder) } - return resp, nil + // Mutate the response parameter in-place. + if params.Response != nil { + if writer, ok := params.Response.(io.Writer); ok { + _, err = io.Copy(writer, resp.Body) + } else { + err = json.NewDecoder(resp.Body).Decode(params.Response) + } + if err != nil { + if err == io.EOF { + if params.ResponseIsOptional { + // The response is optional, so we should ignore the + // io.EOF error + return &CallResponse{ + StatusCode: resp.StatusCode, + Header: resp.Header, + }, nil + } + return nil, fmt.Errorf("expected a %T response, but the server responded with nothing", params.Response) + } + return nil, err + } + } + + return &CallResponse{ + StatusCode: resp.StatusCode, + Header: resp.Header, + }, nil } // buildURL constructs the final URL by appending the given query parameters (if any). @@ -250,7 +225,7 @@ func decodeError(response *http.Response, errorDecoder ErrorDecoder) error { // This endpoint has custom errors, so we'll // attempt to unmarshal the error into a structured // type based on the status code. - return errorDecoder(response.StatusCode, response.Body) + return errorDecoder(response.StatusCode, response.Header, response.Body) } // This endpoint doesn't have any custom error // types, so we just read the body as-is, and @@ -263,9 +238,9 @@ func decodeError(response *http.Response, errorDecoder ErrorDecoder) error { // The error didn't have a response body, // so all we can do is return an error // with the status code. - return core.NewAPIError(response.StatusCode, nil) + return core.NewAPIError(response.StatusCode, response.Header, nil) } - return core.NewAPIError(response.StatusCode, errors.New(string(bytes))) + return core.NewAPIError(response.StatusCode, response.Header, errors.New(string(bytes))) } // isNil is used to determine if the request value is equal to nil (i.e. an interface diff --git a/seed/go-sdk/imdb/no-custom-config/internal/caller_test.go b/seed/go-sdk/imdb/no-custom-config/internal/caller_test.go index b346bca9f84..fc5e597cc18 100644 --- a/seed/go-sdk/imdb/no-custom-config/internal/caller_test.go +++ b/seed/go-sdk/imdb/no-custom-config/internal/caller_test.go @@ -102,6 +102,7 @@ func TestCall(t *testing.T) { wantError: &NotFoundError{ APIError: core.NewAPIError( http.StatusNotFound, + http.Header{}, errors.New(`{"message":"ID \"404\" not found"}`), ), }, @@ -115,6 +116,7 @@ func TestCall(t *testing.T) { giveRequest: nil, wantError: core.NewAPIError( http.StatusBadRequest, + http.Header{}, errors.New("invalid request"), ), }, @@ -140,6 +142,7 @@ func TestCall(t *testing.T) { }, wantError: core.NewAPIError( http.StatusInternalServerError, + http.Header{}, errors.New("failed to process request"), ), }, @@ -212,7 +215,7 @@ func TestCall(t *testing.T) { }, ) var response *Response - err := caller.Call( + _, err := caller.Call( context.Background(), &CallParams{ URL: server.URL + test.givePathSuffix, @@ -236,67 +239,6 @@ func TestCall(t *testing.T) { } } -func TestCallRaw(t *testing.T) { - tests := []*TestCase{ - { - description: "HEAD success", - giveMethod: http.MethodHead, - giveHeader: http.Header{ - "X-API-Status": []string{"success"}, - }, - wantHeaders: http.Header{ - "Content-Length": []string{"250"}, - "Date": []string{"1970-01-01"}, - }, - }, - } - for _, test := range tests { - t.Run(test.description, func(t *testing.T) { - server := httptest.NewServer( - http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, test.giveMethod, r.Method) - for header, value := range test.giveHeader { - assert.Equal(t, value, r.Header.Values(header)) - } - for header, values := range test.wantHeaders { - for _, value := range values { - w.Header().Add(header, value) - } - } - w.WriteHeader(http.StatusOK) - }, - ), - ) - defer server.Close() - - caller := NewCaller( - &CallerParams{ - Client: server.Client(), - }, - ) - response, err := caller.CallRaw( - context.Background(), - &CallRawParams{ - URL: server.URL + test.givePathSuffix, - Method: test.giveMethod, - Headers: test.giveHeader, - BodyProperties: test.giveBodyProperties, - QueryParameters: test.giveQueryParams, - Request: test.giveRequest, - ErrorDecoder: test.giveErrorDecoder, - }, - ) - if test.wantError != nil { - assert.EqualError(t, err, test.wantError.Error()) - return - } - require.NoError(t, err) - assert.Equal(t, test.wantHeaders, response.Header) - }) - } -} - func TestMergeHeaders(t *testing.T) { t.Run("both empty", func(t *testing.T) { merged := MergeHeaders(make(http.Header), make(http.Header)) @@ -432,13 +374,13 @@ func newTestServer(t *testing.T, tc *TestCase) *httptest.Server { } // newTestErrorDecoder returns an error decoder suitable for tests. -func newTestErrorDecoder(t *testing.T) func(int, io.Reader) error { - return func(statusCode int, body io.Reader) error { +func newTestErrorDecoder(t *testing.T) func(int, http.Header, io.Reader) error { + return func(statusCode int, header http.Header, body io.Reader) error { raw, err := io.ReadAll(body) require.NoError(t, err) var ( - apiError = core.NewAPIError(statusCode, errors.New(string(raw))) + apiError = core.NewAPIError(statusCode, header, errors.New(string(raw))) decoder = json.NewDecoder(bytes.NewReader(raw)) ) if statusCode == http.StatusNotFound { diff --git a/seed/go-sdk/imdb/no-custom-config/internal/error_decoder.go b/seed/go-sdk/imdb/no-custom-config/internal/error_decoder.go index 311f35eda26..c22bbf0b399 100644 --- a/seed/go-sdk/imdb/no-custom-config/internal/error_decoder.go +++ b/seed/go-sdk/imdb/no-custom-config/internal/error_decoder.go @@ -6,26 +6,28 @@ import ( "errors" "fmt" "io" + "net/http" "github.com/imdb/fern/core" ) // ErrorDecoder decodes *http.Response errors and returns a // typed API error (e.g. *core.APIError). -type ErrorDecoder func(statusCode int, body io.Reader) error +type ErrorDecoder func(statusCode int, header http.Header, body io.Reader) error // ErrorCodes maps HTTP status codes to error constructors. type ErrorCodes map[int]func(*core.APIError) error // NewErrorDecoder returns a new ErrorDecoder backed by the given error codes. func NewErrorDecoder(errorCodes ErrorCodes) ErrorDecoder { - return func(statusCode int, body io.Reader) error { + return func(statusCode int, header http.Header, body io.Reader) error { raw, err := io.ReadAll(body) if err != nil { return fmt.Errorf("failed to read error from response body: %w", err) } apiError := core.NewAPIError( statusCode, + header, errors.New(string(raw)), ) newErrorFunc, ok := errorCodes[statusCode] diff --git a/seed/go-sdk/imdb/no-custom-config/internal/error_decoder_test.go b/seed/go-sdk/imdb/no-custom-config/internal/error_decoder_test.go index 77febc34311..6b5265c92f9 100644 --- a/seed/go-sdk/imdb/no-custom-config/internal/error_decoder_test.go +++ b/seed/go-sdk/imdb/no-custom-config/internal/error_decoder_test.go @@ -21,35 +21,39 @@ func TestErrorDecoder(t *testing.T) { tests := []struct { description string giveStatusCode int + giveHeader http.Header giveBody string wantError error }{ { description: "unrecognized status code", giveStatusCode: http.StatusInternalServerError, + giveHeader: http.Header{}, giveBody: "Internal Server Error", - wantError: core.NewAPIError(http.StatusInternalServerError, errors.New("Internal Server Error")), + wantError: core.NewAPIError(http.StatusInternalServerError, http.Header{}, errors.New("Internal Server Error")), }, { description: "not found with valid JSON", giveStatusCode: http.StatusNotFound, + giveHeader: http.Header{}, giveBody: `{"message": "Resource not found"}`, wantError: &NotFoundError{ - APIError: core.NewAPIError(http.StatusNotFound, errors.New(`{"message": "Resource not found"}`)), + APIError: core.NewAPIError(http.StatusNotFound, http.Header{}, errors.New(`{"message": "Resource not found"}`)), Message: "Resource not found", }, }, { description: "not found with invalid JSON", giveStatusCode: http.StatusNotFound, + giveHeader: http.Header{}, giveBody: `Resource not found`, - wantError: core.NewAPIError(http.StatusNotFound, errors.New("Resource not found")), + wantError: core.NewAPIError(http.StatusNotFound, http.Header{}, errors.New("Resource not found")), }, } for _, tt := range tests { t.Run(tt.description, func(t *testing.T) { - assert.Equal(t, tt.wantError, decoder(tt.giveStatusCode, bytes.NewReader([]byte(tt.giveBody)))) + assert.Equal(t, tt.wantError, decoder(tt.giveStatusCode, tt.giveHeader, bytes.NewReader([]byte(tt.giveBody)))) }) } } diff --git a/seed/go-sdk/imdb/no-custom-config/internal/retrier_test.go b/seed/go-sdk/imdb/no-custom-config/internal/retrier_test.go index 16e7b361c0c..7fa4557fcda 100644 --- a/seed/go-sdk/imdb/no-custom-config/internal/retrier_test.go +++ b/seed/go-sdk/imdb/no-custom-config/internal/retrier_test.go @@ -107,7 +107,7 @@ func TestRetrier(t *testing.T) { ) var response *Response - err := caller.Call( + _, err := caller.Call( context.Background(), &CallParams{ URL: server.URL,